diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp index 4ded8ba55013d..24fbc1dca8361 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -56,11 +56,27 @@ struct InParallelOpInterface } }; +struct ReduceReturnOpInterface + : public BufferDeallocationOpInterface::ExternalModel< + ReduceReturnOpInterface, scf::ReduceReturnOp> { + FailureOr process(Operation *op, DeallocationState &state, + const DeallocationOptions &options) const { + auto reduceReturnOp = cast(op); + if (isa(reduceReturnOp.getOperand().getType())) + return op->emitError("only supported when operand is not a MemRef"); + + SmallVector updatedOperandOwnership; + return deallocation_impl::insertDeallocOpForReturnLike( + state, op, {}, updatedOperandOwnership); + } +}; + } // namespace void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) { InParallelOp::attachInterface(*ctx); + ReduceReturnOp::attachInterface(*ctx); }); } diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir index 0847b1f1183f9..99cfed99c02d1 100644 --- a/mlir/test/Dialect/SCF/buffer-deallocation.mlir +++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir @@ -22,3 +22,31 @@ func.func @parallel_insert_slice(%arg0: index) { // CHECK: } // CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true // CHECK-NOT: retain + +// ----- + +func.func @reduce(%buffer: memref<100xf32>) { + %init = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 { + %elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32> + scf.reduce(%elem_to_reduce) : f32 { + ^bb0(%lhs : f32, %rhs: f32): + %alloc = memref.alloc() : memref<2xf32> + memref.store %lhs, %alloc [%c0] : memref<2xf32> + memref.store %rhs, %alloc [%c1] : memref<2xf32> + %0 = memref.load %alloc[%c0] : memref<2xf32> + %1 = memref.load %alloc[%c1] : memref<2xf32> + %res = arith.addf %0, %1 : f32 + scf.reduce.return %res : f32 + } + } + func.return +} + +// CHECK-LABEL: func @reduce +// CHECK: scf.reduce +// CHECK: [[ALLOC:%.+]] = memref.alloc( +// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true +// CHECK: scf.reduce.return