diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 114d79555dcef..efd8d573936c3 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform); void modifiesPayload(SmallVectorImpl &effects); void onlyReadsPayload(SmallVectorImpl &effects); +/// Checks whether the transform op modifies the payload. +bool doesModifyPayload(transform::TransformOpInterface transform); +/// Checks whether the transform op reads the payload. +bool doesReadPayload(transform::TransformOpInterface transform); + /// Populates `consumedArguments` with positions of `block` arguments that are /// consumed by the operations in the `block`. void getConsumedBlockArguments( diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index ed987ac4b5164..00450a1ff8f36 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload( effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); } +bool transform::doesModifyPayload(transform::TransformOpInterface transform) { + auto iface = cast(transform.getOperation()); + SmallVector effects; + iface.getEffects(effects); + return ::hasEffect(effects); +} + +bool transform::doesReadPayload(transform::TransformOpInterface transform) { + auto iface = cast(transform.getOperation()); + SmallVector effects; + iface.getEffects(effects); + return ::hasEffect(effects); +} + void transform::getConsumedBlockArguments( Block &block, llvm::SmallDenseSet &consumedArguments) { SmallVector effects; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 7bbbbba4134b1..a56adcfd7fd84 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector> resultOps(getNumResults(), {}); - - for (Operation *op : state.getPayloadOps(getTarget())) { + // Store payload ops in a vector because ops may be removed from the mapping + // by the TrackingRewriter while the iteration is in progress. + SmallVector targets = + llvm::to_vector(state.getPayloadOps(getTarget())); + for (Operation *op : targets) { auto scope = state.make_region_scope(getBody()); if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) return DiagnosedSilenceableFailure::definiteFailure(); @@ -1152,6 +1155,7 @@ void transform::ForeachOp::getEffects( SmallVectorImpl &effects) { BlockArgument iterVar = getIterationVariable(); if (any_of(getBody().front().without_terminator(), [&](Operation &op) { + return isHandleConsumed(iterVar, cast(&op)); })) { consumesHandle(getTarget(), effects); @@ -1159,6 +1163,16 @@ void transform::ForeachOp::getEffects( onlyReadsHandle(getTarget(), effects); } + if (any_of(getBody().front().without_terminator(), [&](Operation &op) { + return doesModifyPayload(cast(&op)); + })) { + modifiesPayload(effects); + } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) { + return doesReadPayload(cast(&op)); + })) { + onlyReadsPayload(effects); + } + for (Value result : getResults()) producesHandle(result, effects); } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index db97d0a088757..68e3a48515396 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -691,6 +691,28 @@ transform.with_pdl_patterns { // ----- +// CHECK-LABEL: func @consume_in_foreach() +// CHECK-NEXT: return +func.func @consume_in_foreach() { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %2 = arith.constant 2 : index + %3 = arith.constant 3 : index + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.foreach %f : !transform.any_op { + ^bb2(%arg2: !transform.any_op): + // expected-remark @below {{erasing}} + transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op + } +} + +// ----- + func.func @bar() { scf.execute_region { // expected-remark @below {{transform applied}} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index afd5011f17c6d..21f9ff5999a5e 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { emitRemark() << getRemark(); for (Operation *op : state.getPayloadOps(getTarget())) - op->erase(); + rewriter.eraseOp(op); if (getFailAfterErase()) return emitSilenceableError() << "silenceable error";