diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index b08283f007078..e082e2c54ef36 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -253,10 +253,7 @@ struct DimOpInterface
   }
 };
 
-/// Bufferization of tensor.empty. This op does not bufferize, but we need an
-/// interface implementation, so that the result of this op is considered
-/// "writable" (default impl. of `isWritable`). Results of ops that do not
-/// implement `BufferizableOpInterface` are not writable.
+/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
 struct EmptyOpInterface
     : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                     tensor::EmptyOp> {
@@ -268,17 +265,21 @@ struct EmptyOpInterface
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
+    auto emptyOp = cast<tensor::EmptyOp>(op);
+
+    // Optimization: Fold away the op if it has no uses.
     if (op->getUses().empty()) {
       rewriter.eraseOp(op);
       return success();
     }
 
-    // tensor.empty ops are used to indicate the shape of a tensor. They have
-    // no defined contents and cannot be bufferized. However, they can be
-    // converted to bufferization.alloc_tensor ops, which then bufferize to an
-    // allocation (--empty-tensor-to-alloc-tensor).
-    return op->emitOpError("cannot be bufferized, but can be converted to "
-                           "bufferization.alloc_tensor");
+    // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
+    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
+        rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
+    if (failed(allocTensor))
+      return failure();
+    rewriter.replaceOp(op, *allocTensor);
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c7b16315bfed1..a8b3c6af9ae89 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @dim(
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>,
@@ -62,9 +62,12 @@ func.func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
 }
 
 // -----
+
+// CHECK-LABEL: func @tensor.empty(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<5xf32>
+// CHECK: return %[[RET]] : tensor<5xf32>
 func.func @tensor.empty() -> tensor<5xf32> {
-  // expected-error@+2 {{failed to bufferize op}}
-  // expected-error@+1 {{cannot be bufferized, but can be converted to bufferization.alloc_tensor}}
   %0 = tensor.empty() : tensor<5xf32>
   return %0 : tensor<5xf32>
 }