diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index a67ea0334b22b..33ebebbf53991 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -995,6 +995,20 @@ struct ReshapeOpInterface bufferization::getBufferType(reshapeOp.getResult(), options); if (failed(maybeResultMemRefType)) return failure(); + + // memref.reshape requires the source buffer to have an identity layout. + // If the source memref does not have an identity layout, clone the source + // into a new buffer with an identity layout. + auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType()); + if (srcType && !srcType.getLayout().isIdentity()) { + auto identityType = + MemRefType::get(srcType.getShape(), srcType.getElementType()); + srcBuffer = rewriter + .create<bufferization::CloneOp>(op->getLoc(), + identityType, *srcBuffer) + .getResult(); + } + replaceOpWithNewBufferizedOp<memref::ReshapeOp>( rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer); return success(); diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index 2aeb5a820812e..13d520aa40723 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -418,3 +418,24 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> { // CHECK: return %[[RESHAPED]] return %reshaped : tensor<2x2x5xf32> } + +// ----- + +// CHECK-LABEL: @reshape_with_non_identity_layout( +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>>, +// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>) +func.func @reshape_with_non_identity_layout(%arg0: tensor<2x2xf32>, %arg1: tensor<2xi32>) -> tensor<1x2xf32> { + + // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2xf32, strided<[?], offset: ?>> + %extracted_slice = tensor.extract_slice %arg0[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32> + + // To satisify the constraints of memref.reshape, the subview must be cloned into + // a buffer with an identity layout. + // CHECK: %[[CLONED:.+]] = bufferization.clone %[[SUBVIEW]] : memref<2xf32, strided<[?], offset: ?>> to memref<2xf32> + // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[CLONED]](%[[LAYOUT]]) : (memref<2xf32>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32> + + %reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32> + + // CHECK: return %[[RESHAPED]] : memref<1x2xf32> + return %reshape : tensor<1x2xf32> +}