diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 94f2002fc51fa..085ae4c93b829 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -68,8 +68,14 @@ bool linalg::isaCopyOpInterface(LinalgOp op) { !mapRange.back().isIdentity()) { return false; } - // Region. - return llvm::hasSingleElement(op.getBlock()->getOperations()); + // Check yield first block argument. + Block *body = op.getBlock(); + if (body->getOperations().size() != 1) + return false; + auto yieldOp = dyn_cast(body->back()); + if (!yieldOp || yieldOp.getNumOperands() != 1) + return false; + return yieldOp->getOperand(0) == body->getArgument(0); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir index 357f2c11a7936..5d66837fca510 100644 --- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir +++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir @@ -29,3 +29,20 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x } -> tensor<8xi32> return %res : tensor<8xi32> } + +// ----- + +#map = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func @not_copy +// CHECK-NOT: linalg.copy +// CHECK: linalg.generic +func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> { + %c0_i32 = arith.constant 0 : i32 + %res = linalg.generic { + indexing_maps = [#map, #map], iterator_types = ["parallel"] + } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) { + ^bb0(%in: i32, %out: i32): + linalg.yield %c0_i32 : i32 + } -> tensor<8xi32> + return %res : tensor<8xi32> +}