diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d24721f3defa6..a301b919dc523 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
+  //    byte are placed in one vector and the high i4 elements in another vector.
+  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
@@ -1048,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// LLVM to scramble with peephole optimizations. Templated to choose between
+/// signed and unsigned conversions.
 ///
-/// For example:
+/// For example (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
 ///  is rewriten as
 ///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1069,16 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///    %4 = vector.interleave %2, %3 : vector<4xi8>
 ///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+/// Example (unsigned):
+///    arith.extui %in : vector<8xi4> to vector<8xi32>
+///  is rewritten as
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType, bool isSigned>
+struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    auto srcVecType = cast<VectorType>(srcValue.getType());
+    auto dstVecType = cast<VectorType>(conversionOp.getType());
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
@@ -1089,8 +1131,14 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    Value subByteExt;
+    if (isSigned) {
+      subByteExt =
+          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    } else {
+      subByteExt =
+          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    }
 
     // Finalize the rewrite.
     rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1229,10 +1277,12 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
+               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c..614b2d4945348 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -324,6 +324,47 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
   return %0 : vector<16x8xi7>
 }
 
+// CHECK-LABEL: func.func @aligned_extui(
+func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_2d(
+func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_base_case(
+func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
@@ -335,4 +376,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
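
Note (not part of the patch): a minimal sketch of how these rewrites are typically driven from C++, assuming the standard greedy pattern driver. The wrapper name applyNarrowTypeRewrites is hypothetical; populateVectorNarrowTypeRewritePatterns (extended above) and applyPatternsAndFoldGreedily are the existing MLIR entry points, and this mirrors what the transform sequence at the end of the test file exercises.

    // Sketch: register the narrow-type rewrite patterns, including the new
    // unsigned i4 -> i8 extension pattern, and apply them greedily.
    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    static LogicalResult applyNarrowTypeRewrites(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      // Adds the generic narrow-type patterns plus the aligned ones registered
      // at benefit + 1, so RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>
      // is tried before the generic emulation path.
      vector::populateVectorNarrowTypeRewritePatterns(patterns);
      return applyPatternsAndFoldGreedily(root, std::move(patterns));
    }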