From 0bc04751b99726ffaf30fcbdad10bf607a68d2cc Mon Sep 17 00:00:00 2001 From: cxy Date: Sun, 5 Nov 2023 17:31:44 +0800 Subject: [PATCH] [mlir] Clone simplify fails when input and result type not cast compatiable Fixed a bug that caused a cast-incompatible memref.cast operation when simplifying the clone operation. --- .../Dialect/Bufferization/IR/BufferizationOps.cpp | 5 +++++ mlir/test/Dialect/Bufferization/canonicalize.mlir | 12 ++++++++++++ 2 files changed, 17 insertions(+) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index ca0d2f407c2d8..94bc2bcea63be 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -457,6 +457,11 @@ struct SimplifyClones : public OpRewritePattern { } Value source = cloneOp.getInput(); + if (source.getType() != cloneOp.getType() && + !memref::CastOp::areCastCompatible({source.getType()}, + {cloneOp.getType()})) + return failure(); + // Aims to find the dealloc op for the canonical source // which otherwise could prevent removal of unnecessary allocs. Value canonicalSource = source; diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir index 3ba283928a83f..3edae7827f25f 100644 --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -156,6 +156,18 @@ func.func @clone_and_cast(%arg0: memref) -> memref<32xf32> { // ----- +// CHECK-LABEL: @clone_incompatible +func.func @clone_incompatible(%arg0: memref<32xf32, strided<[2]>>) -> memref<32xf32> { + %0 = bufferization.clone %arg0 : memref<32xf32, strided<[2]>> to memref<32xf32> + memref.dealloc %arg0 : memref<32xf32, strided<[2]>> + return %0 : memref<32xf32> +} +// CHECK-SAME: %[[ARG:.*]]: memref<32xf32, strided<[2]>> +// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<32xf32, strided<[2]>> to memref<32xf32> +// CHECK-NOT: memref.cast + +// ----- + // CHECK-LABEL: @alias_is_freed func.func @alias_is_freed(%arg0 : memref) { %0 = memref.cast %arg0 : memref to memref<32xf32>