diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h index 7f445fee5ba6b..78c79c915e060 100644 --- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h +++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h @@ -20,7 +20,13 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" namespace arith { -void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns); +/// Add patterns for rewriting `arith.extf` and `arith.truncf` on FP8 types +/// to wrappers around AMDGPU-specific intrinsics. If `saturateFP8TruncF` +/// is set, values outside the range of the destination type are clamped +/// to the largest value of that type instead of overflowing to +/// NaN. +void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, + bool saturateFP8TruncF); } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 3467e042c493e..ec0a6284fe97d 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> { }]; let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"]; + + let options = [ + Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool", + /*default=*/"false", + "Use saturating truncation for 8-bit float types">, + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index 62a84ee7903d7..402bd196f0736 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -53,6 +53,17 @@ Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool
isUnsignedCast); +/// Create a constant of type `type` at location `loc` whose value is `value` +/// (an APInt or APFloat whose type must match the element type of `type`). +/// If `type` is a shaped type, create a splat constant of the given value. +/// Constants are folded if possible. +Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, + const APInt &value); +Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, + int64_t value); +Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, + const APFloat &value); + /// Helper struct to build simple arithmetic quantities with minimal type /// inference support. struct ArithBuilder { diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 7785405eae67b..c625a302a3970 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" @@ -34,17 +35,17 @@ struct ArithToAMDGPUConversionPass final void runOnOperation() override; }; -struct ExtfOnFloat8RewritePattern final - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ExtFOnFloat8RewritePattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; LogicalResult match(arith::ExtFOp op) const override; void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; }; -struct TruncfToFloat8RewritePattern final - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct TruncFToFloat8RewritePattern final : OpRewritePattern { + bool saturateFP8 = false; + TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8) + : 
OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {} LogicalResult match(arith::TruncFOp op) const override; void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; @@ -62,7 +63,7 @@ static Value castF32To(Type elementType, Value f32, Location loc, llvm_unreachable("The only 32-bit float type is f32"); } -LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const { +LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { Type inType = op.getIn().getType(); if (auto inVecType = inType.dyn_cast()) { if (inVecType.isScalable()) @@ -75,7 +76,7 @@ LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const { return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ()); } -void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op, +void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value in = op.getIn(); @@ -93,11 +94,13 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op, Value result = rewriter.createOrFold(loc, op.getOut().getType(), zero); if (inType.getShape().empty()) { - Value scalarIn = rewriter.create(loc, in); + Value scalarIn = + rewriter.create(loc, in, ArrayRef{}); // Recurse to send the 0-D vector case to the 1-D vector case Value scalarExt = rewriter.create(loc, outElemType, scalarIn); - result = rewriter.create(loc, scalarExt, zero); + result = rewriter.create(loc, scalarExt, zero, + ArrayRef{}); return rewriter.replaceOp(op, result); } for (int64_t i = 0; i < numElements; i += 4) { @@ -108,9 +111,7 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op, Value asFloat = rewriter.create( loc, rewriter.getF32Type(), inSlice, j); Value asType = castF32To(outElemType, asFloat, loc, rewriter); - result = rewriter.create( - loc, asType, result, - rewriter.createOrFold(loc, i + j)); + result = rewriter.create(loc, asType, result, i + j); } } rewriter.replaceOp(op, result); @@ -127,7 +128,53 
@@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { llvm_unreachable("The only 32-bit float type is f32"); } -LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const { +// If `source` is finite, clamp it between the maximum and minimum values +// of `outElemType` so that subsequent conversion instructions don't +// overflow those out-of-range values to NaN. These semantics are commonly +// used in machine-learning contexts where failure to clamp would lead to +// excessive NaN production. +static Value clampInput(PatternRewriter &rewriter, Location loc, + Type outElemType, Value source) { + Type sourceType = source.getType(); + const llvm::fltSemantics &sourceSem = + cast(getElementTypeOrSelf(sourceType)).getFloatSemantics(); + const llvm::fltSemantics &targetSem = + cast(outElemType).getFloatSemantics(); + + APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true); + APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false); + bool ignoredLosesInfo = false; + // We can ignore conversion failures here because this conversion promotes + // from a smaller type to a larger one - ex. there can be no loss of precision + // when casting fp8 to f16.
+ (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); + (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); + + Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min); + Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max); + + Value inf = createScalarOrSplatConstant( + rewriter, loc, sourceType, + APFloat::getInf(sourceSem, /*Negative=*/false)); + Value negInf = createScalarOrSplatConstant( + rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true)); + Value isInf = rewriter.createOrFold( + loc, arith::CmpFPredicate::OEQ, source, inf); + Value isNegInf = rewriter.createOrFold( + loc, arith::CmpFPredicate::OEQ, source, negInf); + Value isNan = rewriter.createOrFold( + loc, arith::CmpFPredicate::UNO, source, source); + Value isNonFinite = rewriter.create( + loc, rewriter.create(loc, isInf, isNegInf), isNan); + + Value clampedBelow = rewriter.create(loc, source, minCst); + Value clamped = rewriter.create(loc, clampedBelow, maxCst); + Value res = + rewriter.create(loc, isNonFinite, source, clamped); + return res; +} + +LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { Type outType = op.getOut().getType(); if (auto outVecType = outType.dyn_cast()) { if (outVecType.isScalable()) @@ -137,22 +184,27 @@ LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const { return failure(); outType = outVecType.getElementType(); } + auto inType = dyn_cast(getElementTypeOrSelf(op.getIn().getType())); + if (inType && inType.getWidth() <= 8 && saturateFP8) + // Conversion between 8-bit floats is not supported with saturation enabled.
+ return failure(); return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()); } -void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op, +void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); + if (saturateFP8) + in = clampInput(rewriter, loc, outElemType, in); VectorType truncResType = VectorType::get(4, outElemType); if (!in.getType().isa()) { Value asFloat = castToF32(in, loc, rewriter); Value asF8s = rewriter.create( loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, /*existing=*/nullptr); - Value result = rewriter.create( - loc, asF8s, rewriter.createOrFold(loc, 0)); + Value result = rewriter.create(loc, asF8s, 0); return rewriter.replaceOp(op, result); } VectorType outType = op.getOut().getType().cast(); @@ -161,11 +213,13 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); Value result = rewriter.createOrFold(loc, outType, zero); if (outType.getShape().empty()) { - Value scalarIn = rewriter.create(loc, in); + Value scalarIn = + rewriter.create(loc, in, ArrayRef{}); // Recurse to send the 0-D vector case to the 1-D vector case Value scalarTrunc = rewriter.create(loc, outElemType, scalarIn); - result = rewriter.create(loc, scalarTrunc, zero); + result = rewriter.create(loc, scalarTrunc, zero, + ArrayRef{}); return rewriter.replaceOp(op, result); } @@ -173,14 +227,11 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op, int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value thisResult = nullptr; for (int64_t j = 0; j < elemsThisOp; j += 2) { - Value elemA = rewriter.create( - loc, in, rewriter.create(loc, i + j)); + Value elemA = rewriter.create(loc, in, i + j); Value asFloatA = castToF32(elemA, loc, rewriter); Value asFloatB = nullptr; if (j + 1 < elemsThisOp) { - Value elemB = 
rewriter.create( - loc, in, - rewriter.createOrFold(loc, i + j + 1)); + Value elemB = rewriter.create(loc, in, i + j + 1); asFloatB = castToF32(elemB, loc, rewriter); } thisResult = rewriter.create( @@ -196,15 +247,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op, } void mlir::arith::populateArithToAMDGPUConversionPatterns( - RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + RewritePatternSet &patterns, bool saturateFP8TruncF) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext(), + saturateFP8TruncF); } void ArithToAMDGPUConversionPass::runOnOperation() { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); - arith::populateArithToAMDGPUConversionPatterns(patterns); + arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt index 359015b6f86ad..e2c951b0b34d8 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU LINK_LIBS PUBLIC MLIRAMDGPUDialect MLIRArithDialect + MLIRArithUtils MLIRVectorDialect MLIRPass MLIRTransforms diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp index 9e783c51c63d1..8a4080ea01970 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -58,35 
+59,6 @@ static Type reduceInnermostDim(VectorType type) { return VectorType::get(newShape, type.getElementType()); } -/// Returns a constant of integer of vector type filled with (repeated) `value`. -static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - const APInt &value) { - TypedAttr attr; - if (dyn_cast(type)) { - attr = rewriter.getIntegerAttr(type, value); - } else { - auto vecTy = cast(type); - attr = SplatElementsAttr::get(vecTy, value); - } - - return rewriter.create(loc, attr); -} - -/// Returns a constant of integer of vector type filled with (repeated) `value`. -static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - int64_t value) { - unsigned elementBitWidth = 0; - if (auto intTy = dyn_cast(type)) - elementBitWidth = intTy.getWidth(); - else - elementBitWidth = cast(type).getElementTypeBitWidth(); - - return createScalarOrSplatConstant(rewriter, loc, type, - APInt(elementBitWidth, value)); -} - /// Extracts the `input` vector slice with elements at the last dimension offset /// by `lastOffset`. 
Returns a value of vector type with the last dimension /// reduced to x1 or fully scalarized, e.g.: diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 0f39c24fb917d..bf274d4ae27ed 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -197,6 +197,40 @@ mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, })); } +Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, + Type type, const APInt &value) { + TypedAttr attr; + if (isa(type)) { + attr = builder.getIntegerAttr(type, value); + } else { + auto vecTy = cast(type); + attr = SplatElementsAttr::get(vecTy, value); + } + + return builder.create(loc, attr); +} + +Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, + Type type, int64_t value) { + unsigned elementBitWidth = 0; + if (auto intTy = dyn_cast(type)) + elementBitWidth = intTy.getWidth(); + else + elementBitWidth = cast(type).getElementTypeBitWidth(); + + return createScalarOrSplatConstant(builder, loc, type, + APInt(elementBitWidth, value)); +} + +Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, + Type type, const APFloat &value) { + if (isa(type)) + return builder.createOrFold( + loc, type, builder.getFloatAttr(type, value)); + TypedAttr splat = SplatElementsAttr::get(cast(type), value); + return builder.createOrFold(loc, type, splat); +} + Value ArithBuilder::_and(Value lhs, Value rhs) { return b.create(loc, lhs, rhs); } diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir new file mode 100644 index 0000000000000..c7f39440a349b --- /dev/null +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt --split-input-file %s \ +// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \ +// RUN: | FileCheck %s + +// 
CHECK-LABEL: func.func @scalar_trunc +// CHECK-SAME: ([[V:%.+]]: f16) +// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16 +// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16 +// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16 +// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16 +// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]] +// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]] +// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]] +// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]] +// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]] +// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]] +// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]] +// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]] +// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32 +// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ> +// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> +// CHECK: return [[W]] : f8E5M2FNUZ +func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ { + %w = arith.truncf %v : f16 to f8E5M2FNUZ + return %w : f8E5M2FNUZ +} + +// No 0-D test because arith.truncf hasn't been extended to support it. 
+ +// ----- + +// CHECK-LABEL: func.func @vector_trunc +// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FNUZ> { +// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-2.400000e+02> : vector<2xf32> +// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<2.400000e+02> : vector<2xf32> +// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32> +// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32> +// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]] +// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]] +// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]] +// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]] +// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]] +// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]] +// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]] +// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]] +// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0] +// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1] +// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FNUZ> +// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FNUZ> to vector<2xf8E4M3FNUZ> +// CHECK: return [[W]] : vector<2xf8E4M3FNUZ> +func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FNUZ> { + %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FNUZ> + return %w : vector<2xf8E4M3FNUZ> +} diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir index a6c11d022e2c1..159a2f02f0560 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir @@ -17,14 +17,12 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 { // CHECK-LABEL: func.func @vector_ext_short // CHECK-SAME: 
([[V:%.+]]: vector<2xf8E5M2FNUZ>) // CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64> -// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index // CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32 // CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64 -// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]] +// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0] // CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32 // CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] -// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]] +// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1] // CHECK: return [[W1]] : vector<2xf64> func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { @@ -38,27 +36,27 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>) // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} // CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] -// CHECK: [[W0:%.+]] = vector.insertelement [[F0]] +// CHECK: [[W0:%.+]] = vector.insert [[F0]] // CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] -// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]] +// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] // CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] -// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]] +// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] // CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] -// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]] +// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] // CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> // CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] -// 
CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]] +// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] // CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] -// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]] +// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] // CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] -// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]] +// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] // CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] -// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]] +// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] // CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> // CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] -// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]] +// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] // CHECK: return [[W8]] func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> { %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32> @@ -69,10 +67,9 @@ func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> { // CHECK-LABEL: func.func @scalar_trunc // CHECK-SAME: ([[V:%.+]]: f16) -// CHECK: [[C0:%.+]] = arith.constant 0 : index // CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32 // CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ> -// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ> +// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ> // CHECK: return [[W]] : f8E5M2FNUZ func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ { %w = arith.truncf %v : f16 to f8E5M2FNUZ @@ -85,11 +82,9 @@ func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ { // CHECK-LABEL: func.func @vector_trunc_short // CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> { -// CHECK-DAG: [[C0:%.+]] = 
arith.constant 0 : index -// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index -// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index] +// CHECK: [[V0:%.+]] = vector.extract [[V]][0] // CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32 -// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index] +// CHECK: [[V1:%.+]] = vector.extract [[V]][1] // CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32 // CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ> // CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>