diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 3acc383923ca8..c0b3e5540b1df 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -86,10 +86,12 @@ def AMDGPU_ExtPackedFp8Op : Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN, VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source, ConfinedAttr]>:$index)>, - Results<(outs F32:$res)> { - let summary = "Extend one of a vector of packed fp8 values to a float"; + Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> { + let summary = "Extend a fp8 value to a float or a vector of packed fp8 values to two floats"; + let description = [{ - Extend the value `source[index]` to a 32-bit float and return it. + Extend one or two 8-bit floats in `source[index]` to a 32-bit float or + two floats and return them. This rather unusual signature arises from the fact that AMD GPUs cannot easily work with sub 32-bit quantities, so the compiler intrinsics for @@ -97,7 +99,7 @@ def AMDGPU_ExtPackedFp8Op : this operation) take packed vectors of 4 such floats. If the passed-in vector has fewer than four elements, or the input is scalar, - the remaining values in the <4 x i8> will be filled with with + the remaining values in the <4 x i8> will be filled with undefined values as needed. 
}]; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index f194e70ee275b..9a433202e3149 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -681,26 +681,26 @@ def ROCDL_CvtPkRtz: }]; } -def ROCDL_CvtScaleF32PkFp8F16 : +def ROCDL_CvtScaleF32PkFp8F16Op : ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert f16 to packed fp8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp8. - Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed fp8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. }]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) }]; } -def ROCDL_CvtScaleF32PkFp8Bf16 : +def ROCDL_CvtScaleF32PkFp8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert packed bf16 to packed fp8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp8. - Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed fp8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. 
}]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) @@ -708,13 +708,13 @@ def ROCDL_CvtScaleF32PkFp8Bf16 : } -def ROCDL_CvtScaleF32PkBf8F16 : +def ROCDL_CvtScaleF32PkBf8F16Op : ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert f16 to packed bf8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed bf8. - Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed bf8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. }]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) @@ -722,26 +722,26 @@ def ROCDL_CvtScaleF32PkBf8F16 : } -def ROCDL_CvtScaleF32PkBf8Bf16 : +def ROCDL_CvtScaleF32PkBf8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert bf16 to packed bf8"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed bf8. - Store the result in low/high word based on $wordSel, preserving the other word. + Scale `src` by the exponent in `scale`, then convert to packed bf8. + Store the result in low/high word of `old` based on $wordSel, preserving the other word. 
}]; let assemblyFormat = [{ attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) }]; } -def ROCDL_CvtScaleF32SrFp8F16 : +def ROCDL_CvtScaleF32SrFp8F16Op : ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>, Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f16 to packed fp8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. }]; let assemblyFormat = [{ @@ -749,13 +749,13 @@ def ROCDL_CvtScaleF32SrFp8F16 : }]; } -def ROCDL_CvtScaleF32SrBf8F16 : +def ROCDL_CvtScaleF32SrBf8F16Op : ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>, Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f16 to packed bf8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. 
}]; let assemblyFormat = [{ @@ -763,13 +763,13 @@ def ROCDL_CvtScaleF32SrBf8F16 : }]; } -def ROCDL_CvtScaleF32SrFp8Bf16 : +def ROCDL_CvtScaleF32SrFp8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>, Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. }]; let assemblyFormat = [{ @@ -777,13 +777,13 @@ def ROCDL_CvtScaleF32SrFp8Bf16 : }]; } -def ROCDL_CvtScaleF32SrBf8Bf16: +def ROCDL_CvtScaleF32SrBf8Bf16Op : ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>, Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. + Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding + using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. }]; let assemblyFormat = [{ @@ -791,48 +791,74 @@ def ROCDL_CvtScaleF32SrBf8Bf16: }]; } -def ROCDL_CvtScaleF32PkF16Fp8 : +def ROCDL_CvtScaleF32PkF16Fp8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert fp8 to packed f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to packed f16. 
+ let summary = "Convert fp8 to packed f16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed f16, then scale + the packed values by the exponent in `scale`. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) }]; } -def ROCDL_CvtScaleF32PkF16Bf8 : +def ROCDL_CvtScaleF32PkF16Bf8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert bf8 to packed f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to packed f16. + let summary = "Convert bf8 to packed f16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed f16, then scale + the packed values by the exponent in `scale`. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) }]; } -def ROCDL_CvtScaleF16Fp8 : +def ROCDL_CvtScaleF32PkBf16Fp8Op : + ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { + let summary = "Convert fp8 to packed bf16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed bf16, then scale + the packed values by the exponent in `scale`. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) + }]; +} + +def ROCDL_CvtScaleF32PkBf16Bf8Op : + ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { + let summary = "Convert bf8 to packed bf16 and scale"; + let description = [{ Convert `src` based on $wordSel to packed bf16, then scale + the packed values by the exponent in `scale`. 
+ }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) + }]; +} + +def ROCDL_CvtScaleF16Fp8Op : ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> { let summary = "Scale and convert fp8 to f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to f16 store into the `byteSel`th byte of `old`, preserving the others. + let description = [{ Convert `src` based on $wordSel to f16, then scale the value + by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`, + preserving the others. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res) }]; } -def ROCDL_CvtScaleF16Bf8 : +def ROCDL_CvtScaleF16Bf8Op : ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> { let summary = "Scale and convert fp8 to f16"; - let description = [{ Scale `src` based on $wordSel by the exponent in `scale` - then convert to f16 store into the `byteSel`th byte of `old`, preserving the others. + let description = [{ Convert `src` based on $wordSel to f16, then scale the value + by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`, + preserving the others. 
}]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res) @@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 : //===---------------------------------------------------------------------===// // 32-bit float intrinsics //===---------------------------------------------------------------------===// -def ROCDL_CvtScale32PkF32Fp8 : +def ROCDL_CvtScaleF32PkF32Fp8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert packed fp8 to packed f32"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp32. - Store the result in low/high word based on $wordSel, preserving the other word. + Convert `src` based on $wordSel to packed fp32, then scale the packed values by + the exponent in `scale`. Store the result in a vector. }]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) }]; } -def ROCDL_CvtScale32PkF32Bf8 : +def ROCDL_CvtScaleF32PkF32Bf8Op : ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>, Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert packed bf8 to packed f32"; let description = [{ - Scale `src` by the exponent in `scale` then convert to packed fp32. - Store the result in low/high word based on $wordSel, preserving the other word. + Convert `src` based on $wordSel to packed fp32, then scale the packed values by + the exponent in `scale`. Store the result in a vector. 
}]; let assemblyFormat = [{ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) @@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 : //===---------------------------------------------------------------------===// // 8-bit float scale intrinsics //===---------------------------------------------------------------------===// -def ROCDL_CvtScaleF32PkFp8F32: +def ROCDL_CvtScaleF32PkFp8F32Op : ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> { let summary = "Scale and convert two f32's to packed fp8"; @@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32: }]; } -def ROCDL_CvtScaleF32PkBf8F32: +def ROCDL_CvtScaleF32PkBf8F32Op : ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>, Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> { let summary = "Scale and convert two f32's to packed bf8"; @@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32: }]; } -def ROCDL_CvtScaleF32SrFp8F32: +def ROCDL_CvtScaleF32SrFp8F32Op : ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>, Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f32 to fp8 using stochastic rounding"; @@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32: } -def ROCDL_CvtScaleF32SrBf8F32: +def ROCDL_CvtScaleF32SrBf8F32Op : ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>, Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> { let summary = "Scale and convert f32 to bf8 using stochastic rounding"; @@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op : }]; } +def ROCDL_CvtPkF32Fp8Op : + ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, I1:$wordSel)> { + let summary = "Convert packed fp8 to packed f32"; + let description = [{ + Convert `src` based on $wordSel to packed fp32. 
+ }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtPkF32Bf8Op : + ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>, + Arguments<(ins I32:$src, I1:$wordSel)> { + let summary = "Convert packed bf8 to packed f32"; + let description = [{ + Convert `src` based on $wordSel to packed fp32. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $wordSel `]` `:` type($res) + }]; +} def ROCDL_CvtPkBf8F32Op : ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>, diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 949424db7c4d6..3acd470cff7f5 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -959,6 +959,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( Value source = adaptor.getSource(); auto sourceVecType = dyn_cast(op.getSource().getType()); + auto resultVecType = dyn_cast(op.getResult().getType()); Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { @@ -977,13 +978,24 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( source = longVec; } Value i32Source = rewriter.create(loc, i32, source); - Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); - if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { - rewriter.replaceOpWithNewOp(op, f32, i32Source, - wordSel); - } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { - rewriter.replaceOpWithNewOp(op, f32, i32Source, - wordSel); + if (resultVecType) { + Value wordSel = createI1Constant(rewriter, loc, op.getIndex()); + if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + wordSel); + } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + wordSel); + } + } else { + Value byteSel = 
createI32Constant(rewriter, loc, op.getIndex()); + if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + byteSel); + } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + byteSel); + } } return success(); } diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 27be54728c1a1..3596b3235a631 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -83,14 +83,15 @@ static bool isSupportedF8(Type elementType, Chipset chipset) { return false; } -static Value castF32To(Type elementType, Value f32, Location loc, +static Value castF32To(Type desType, Value f32, Location loc, PatternRewriter &rewriter) { + Type elementType = getElementTypeOrSelf(desType); if (elementType.isF32()) return f32; if (elementType.getIntOrFloatBitWidth() < 32) - return rewriter.create(loc, elementType, f32); + return rewriter.create(loc, desType, f32); if (elementType.getIntOrFloatBitWidth() > 32) - return rewriter.create(loc, elementType, f32); + return rewriter.create(loc, desType, f32); llvm_unreachable("The only 32-bit float type is f32"); } @@ -110,6 +111,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); + VectorType extResType = VectorType::get(2, rewriter.getF32Type()); if (!inVecType) { Value asFloat = rewriter.create( loc, rewriter.getF32Type(), in, 0); @@ -150,11 +152,20 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value inSlice = rewriter.create( loc, in, i, elemsThisOp, 1); - for (int64_t j = 0; j < elemsThisOp; ++j) { - Value asFloat = rewriter.create( - loc, rewriter.getF32Type(), inSlice, j); - Value asType = 
castF32To(outElemType, asFloat, loc, rewriter); - result = rewriter.create(loc, asType, result, i + j); + for (int64_t j = 0; j < elemsThisOp; j += 2) { + if (i + j + 1 < numElements) { // Convert two 8-bit elements + Value asFloats = rewriter.create( + loc, extResType, inSlice, j / 2); + Type desType = VectorType::get(2, outElemType); + Value asType = castF32To(desType, asFloats, loc, rewriter); + result = rewriter.create( + loc, asType, result, i + j, 1); + } else { // Convert a 8-bit element + Value asFloat = rewriter.create( + loc, rewriter.getF32Type(), inSlice, j / 2 * 2); + Value asType = castF32To(outElemType, asFloat, loc, rewriter); + result = rewriter.create(loc, asType, result, i + j); + } } } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir index 70775a603e54d..ea0c3afbd9021 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir @@ -9,7 +9,7 @@ // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 // CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 -// CHECK: return [[EXT]] +// CHECK: return [[EXT]] : f32 func.func @ext_scalar(%v: f8E5M2) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32 func.return %ret : f32 @@ -27,7 +27,7 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 { // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 // CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 -// CHECK: return [[EXT]] +// CHECK: return [[EXT]] : f32 func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32 func.return %ret : f32 @@ -39,12 +39,40 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 { // CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : 
i32 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 // CHECK: return [[EXT]] : f32 - func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32 func.return %ret : f32 } +// CHECK-LABEL: func @ext_packed_2xfp8 +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8> +// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8> +// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8> +// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8> +// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> +// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 +// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> +// CHECK: return [[EXT]] +func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32> + func.return %ret : vector<2xf32> +} + +// CHECK-LABEL: func @ext_packed_4xfp8 +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> +// CHECK: return [[EXT]] : vector<2xf32> +func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32> + func.return %ret : vector<2xf32> +} + // CHECK-LABEL: func 
@packed_trunc // CHECK-SAME: ([[V:%.+]]: f32) // CHECK: [[V2:%.+]] = llvm.mlir.undef : f32 diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir index a313aaffdf5cc..219f822ca9a1c 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir @@ -8,7 +8,7 @@ // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 // CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 -// CHECK: return [[EXT]] +// CHECK: return [[EXT]] : f32 func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32 func.return %ret : f32 @@ -26,24 +26,52 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 // CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 -// CHECK: return [[EXT]] +// CHECK: return [[EXT]] : f32 func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32 func.return %ret : f32 } -// CHECK-LABEL: func @ext_full_vec( +// CHECK-LABEL: func @ext_full_vec // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 // CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32 // CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 // CHECK: return [[EXT]] : f32 - func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32 func.return %ret : f32 } +// CHECK-LABEL: func @ext_packed_2xfp8 +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8> +// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : 
vector<4xi8> +// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8> +// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8> +// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> +// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 +// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> +// CHECK: return [[EXT]] +func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32> + func.return %ret : vector<2xf32> +} + +// CHECK-LABEL: func @ext_packed_4xfp8( +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> +// CHECK: return [[EXT]] : vector<2xf32> +func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32> + func.return %ret : vector<2xf32> +} + // CHECK-LABEL: func @packed_trunc // CHECK-SAME: ([[V:%.+]]: f32) // CHECK: [[V2:%.+]] = llvm.mlir.undef : f32 diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir index 0e7f58c9e6749..7fb5fbfe0c89e 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt --split-input-file %s 
-convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s - + // CHECK-LABEL: func.func @scalar_ext // CHECK-SAME: ([[V:%.+]]: f8E5M2) // CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32 @@ -17,14 +17,9 @@ func.func @scalar_ext(%v: f8E5M2) -> f16 { // CHECK-LABEL: func.func @vector_ext_short // CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>) -// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64> -// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32 -// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64 -// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0] -// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32 -// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] -// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1] -// CHECK: return [[W1]] : vector<2xf64> +// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to vector<2xf32> +// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT0]] : vector<2xf32> to vector<2xf64> +// CHECK: return [[EXT]] : vector<2xf64> func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> { %w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64> @@ -35,30 +30,21 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> { // CHECK-LABEL: func.func @vector_ext_long // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>) -// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert [[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] 
= vector.insert [[F3]], [[W2]] - -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: return [[W8]] +// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> +// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> +// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> +// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT4:%.+]] 
= amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> +// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32> +// CHECK: return [[W5]] func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> { %w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32> return %w : vector<9xf32> @@ -143,34 +129,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> { // ----- // CHECK-LABEL: func.func @vector_ext_long_2d -// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>) -// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN> +// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FN>) +// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FN> to vector<11xf8E4M3FN> // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert [[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] - -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = 
vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32> +// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32> + +// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FN> to vector<4xf8E4M3FN> +// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32> + +// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FN> to vector<3xf8E4M3FN> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FN> to 
vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FN> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32> // CHECK: return [[CAST]] -func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> { - %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32> - return %w : vector<1x9xf32> +func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FN>) -> vector<1x11xf32> { + %w = arith.extf %v : vector<1x11xf8E4M3FN> to vector<1x11xf32> + return %w : vector<1x11xf32> } diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir index 6bb5b9771c015..59ed6bd95ae8b 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir @@ -28,15 +28,9 @@ func.func @vector_zero_d(%v: vector) -> vector { // CHECK-LABEL: func.func @vector_ext_short // CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>) -// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64> -// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32 -// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64 -// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0] -// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32 -// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] -// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1] -// CHECK: return [[W1]] : vector<2xf64> - +// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to vector<2xf32> +// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT]] : vector<2xf32> to vector<2xf64> +// CHECK: return [[EXT]] : vector<2xf64> func.func 
@vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { %w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64> return %w : vector<2xf64> @@ -46,30 +40,21 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { // CHECK-LABEL: func.func @vector_ext_long // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>) -// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert [[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] - -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: return [[W8]] +// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32> +// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> +// CHECK: [[FLOAT1:%.+]] = 
amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> +// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32> +// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> +// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32> +// CHECK: return [[W5]] func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> { %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32> return %w : vector<9xf32> @@ -154,34 +139,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> { // ----- // CHECK-LABEL: func.func @vector_ext_long_2d -// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>) -// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ> +// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FNUZ>) +// CHECK: [[CST:%.+]] = arith.constant 
dense<0.000000e+00> : vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FNUZ> to vector<11xf8E4M3FNUZ> // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]} -// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insert [[F0]] -// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] -// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] -// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] - -// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> -// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] -// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] -// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] -// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] - -// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> -// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] -// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32> +// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], 
[[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32> + +// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> +// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32> + +// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<3xf8E4M3FNUZ> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FNUZ> to vector<2xf32> +// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32> +// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FNUZ> to f32 +// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32> +// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32> // CHECK: return [[CAST]] -func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> { - %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32> - return %w : vector<1x9xf32> +func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FNUZ>) -> vector<1x11xf32> { + %w = arith.extf %v : vector<1x11xf8E4M3FNUZ> to vector<1x11xf32> + return %w : vector<1x11xf32> } diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 567e6498330a3..665674f2a7873 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -4,13 +4,20 @@ // Verify the generic form can be parsed. 
// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s -// CHECK-LABEL: func @ext_packed_fp8 -// CHECK: amdgpu.ext_packed_fp8 -func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 { +// CHECK-LABEL: func @ext_packed_fp8_s +// CHECK: amdgpu.ext_packed_fp8 {{.*}} vector<4xf8E4M3FNUZ> to f32 +func.func @ext_packed_fp8_s(%v: vector<4xf8E4M3FNUZ>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32 func.return %ret : f32 } +// CHECK-LABEL: func @ext_packed_fp8_v +// CHECK: amdgpu.ext_packed_fp8 {{.*}} vector<4xf8E4M3FNUZ> to vector<2xf32 +func.func @ext_packed_fp8_v(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { + %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to vector<2xf32> + func.return %ret : vector<2xf32> +} + // CHECK-LABEL: func @packed_trunc_2xfp8 // CHECK: amdgpu.packed_trunc_2xfp8 func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> { diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index bc917041998d8..cce2c0aee62f3 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -767,10 +767,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf // CHECK: rocdl.cvt.scalef32.f32.fp8 // CHECK: rocdl.cvt.scalef32.pk.f16.bf8 // CHECK: rocdl.cvt.scalef32.pk.f16.fp8 +// CHECK: rocdl.cvt.scalef32.pk.bf16.bf8 +// CHECK: rocdl.cvt.scalef32.pk.bf16.fp8 // CHECK: rocdl.cvt.scalef32.f16.fp8 // CHECK: rocdl.cvt.scalef32.f16.bf8 // CHECK: rocdl.cvt.pk.bf8.f32 // CHECK: rocdl.cvt.pk.fp8.f32 +// CHECK: rocdl.cvt.pk.f32.bf8 +// CHECK: rocdl.cvt.pk.f32.fp8 // CHECK: rocdl.cvt.sr.bf8.f32 // CHECK: rocdl.cvt.sr.fp8.f32 // CHECK: rocdl.cvt.scalef32.sr.fp8.f32 @@ -793,10 +797,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf %v2_scaled = rocdl.cvt.scalef32.f32.fp8 
%source[%c0], %c4 : f32 %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16> %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16> + %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16> + %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16> %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16 %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16 %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 + %source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32> + %source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32> %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32 %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32 %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32 diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 11f2faa2761ff..e70617bfff99e 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1042,6 +1042,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false) // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false) // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false) +// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false) +// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false) // CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 
%{{.+}}, i1 false) // CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false) // CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2) @@ -1068,6 +1070,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32 %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16 %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16 + %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32 + %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32 %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32