Bufferization of func.call operands: when the caller's buffer type differs
from the callee's memref argument type, use
bufferization::castOrReallocMemRefValue instead of an asserted memref.cast.
A cast is used when it is statically known to be valid; otherwise a new
buffer is allocated and the data is copied (see the added test, where a
strided subview is passed to a callee that expects an identity layout).

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 053ea7935260a..9fbe574ec392d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,20 +258,23 @@ struct CallOpInterface
         return failure();
       Value buffer = *maybeBuffer;
 
-      // Caller / callee type mismatch is handled with a CastOp.
+      // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
       // that will either canonicalize away or fail compilation until we can do
-      // something better.
+      // something better. Insert a reallocation + copy if it cannot be
+      // statically guaranteed that a direct cast would be valid.
       if (buffer.getType() != memRefType) {
-        assert(
-            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
-            "CallOp::bufferize: cast incompatible");
-        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
-                                                           memRefType, buffer);
-        buffer = castBuffer;
+        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+        assert(memrefDstType &&
+               "buffer layout not supported on unranked tensors");
+        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
+            rewriter, buffer, memrefDstType, options);
+        if (failed(replacement))
+          return failure();
+        buffer = *replacement;
       }
       newOperands.push_back(buffer);
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0248afb11f167..0d5224514e3a0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
 
 // -----
 
+// CHECK-NO-LAYOUT-MAP-LABEL: func.func @foo(
+// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP: return %[[VAL_0]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: }
+func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
+  return %arg0 : tensor<3x8xf16>
+}
+
+// CHECK-NO-LAYOUT-MAP-LABEL: func.func @call_extract_slice(
+// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
+// CHECK-NO-LAYOUT-MAP: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: return %[[VAL_3]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: }
+func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
+  %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
+  %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
+  return %1 : tensor<3x8xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func private @private_func
 // CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
 func.func private @private_func(tensor<?xf32>) -> (f32)