diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 4449733f0daf0..77c108aab4807 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -13,7 +13,6 @@ #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -24,7 +23,6 @@ #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MathExtras.h" #include #include @@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// ConvertMemRefCollapseShape +//===----------------------------------------------------------------------===// + +/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given +/// that we flatten memrefs to a single dimension as part of the emulation and +/// there is no dimension to collapse any further. +struct ConvertMemRefCollapseShape final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value srcVal = adaptor.getSrc(); + auto newTy = dyn_cast(srcVal.getType()); + if (!newTy) + return failure(); + + if (newTy.getRank() != 1) + return failure(); + + rewriter.replaceOp(collapseShapeOp, srcVal); + return success(); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns( // Populate `memref.*` conversion patterns. patterns.add, - ConvertMemRefAllocation, ConvertMemRefLoad, + ConvertMemRefAllocation, + ConvertMemRefCollapseShape, ConvertMemRefLoad, ConvertMemrefStore, ConvertMemRefAssumeAlignment, ConvertMemRefSubview, ConvertMemRefReinterpretCast>( typeConverter, patterns.getContext()); diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index fd37b7ff0a271..435dcc944778d 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () { // CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32 // CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref) -> i32 // CHECK32: return + +// ----- + +func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 { + %arr = memref.alloc() : memref<32x8x128xi4> + %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4> + %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4> + return %1 : i4 +} + +// CHECK-LABEL: func.func @memref_collapse_shape_i4( +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8> +// CHECK-NOT: memref.collapse_shape +// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8> + +// CHECK32-LABEL: func.func @memref_collapse_shape_i4( +// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32> +// CHECK32-NOT: memref.collapse_shape +// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32> +