diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index efd8d573936c3..86af59142b77d 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -170,6 +170,12 @@ class TransformState { /// should be emitted when the value is used. using InvalidatedHandleMap = DenseMap>; +#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS + /// Debug only: A timestamp is associated with each transform IR value, so + /// that invalid iterator usage can be detected more reliably. + using TransformIRTimestampMapping = DenseMap; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + /// The bidirectional mappings between transform IR values and payload IR /// operations, and the mapping between transform IR values and parameters. struct Mappings { @@ -178,6 +184,11 @@ class TransformState { ParamMapping params; ValueMapping values; ValueMapping reverseValues; + +#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS + TransformIRTimestampMapping timestamps; + void incrementTimestamp(Value value) { ++timestamps[value]; } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }; friend LogicalResult applyTransforms(Operation *, TransformOpInterface, @@ -207,10 +218,26 @@ class TransformState { /// not enumerated. This function is helpful for transformations that apply to /// a particular handle. auto getPayloadOps(Value value) const { + ArrayRef view = getPayloadOpsView(value); + +#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS + // Memorize the current timestamp and make sure that it has not changed + // when incrementing or dereferencing the iterator returned by this + // function. The timestamp is incremented when the "direct" mapping is + // resized; this would invalidate the iterator returned by this function. + int64_t currentTimestamp = getMapping(value).timestamps.lookup(value); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + // When ops are replaced/erased, they are replaced with nullptr (until // the data structure is compacted). Do not enumerate these ops. - return llvm::make_filter_range(getPayloadOpsView(value), - [](Operation *op) { return op != nullptr; }); + return llvm::make_filter_range(view, [=](Operation *op) { +#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS + bool sameTimestamp = + currentTimestamp == this->getMapping(value).timestamps.lookup(value); + assert(sameTimestamp && "iterator was invalidated during iteration"); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + return op != nullptr; + }); } /// Returns the list of parameters that the given transform IR value diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 00450a1ff8f36..9cac178d3c2b8 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -310,6 +310,11 @@ void transform::TransformState::forgetMapping(Value opHandle, for (Operation *op : mappings.direct[opHandle]) dropMappingEntry(mappings.reverse, op, opHandle); mappings.direct.erase(opHandle); +#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS + // Payload IR is removed from the mapping. This invalidates the respective + // iterators. + mappings.incrementTimestamp(opHandle); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (Value opResult : origOpFlatResults) { SmallVector resultHandles; @@ -336,6 +341,12 @@ void transform::TransformState::forgetValueMapping( Mappings &localMappings = getMapping(opHandle); dropMappingEntry(localMappings.direct, opHandle, payloadOp); dropMappingEntry(localMappings.reverse, payloadOp, opHandle); + +#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS + // Payload IR is removed from the mapping. This invalidates the respective + // iterators. + localMappings.incrementTimestamp(opHandle); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } } } @@ -774,6 +785,13 @@ checkRepeatedConsumptionInOperand(ArrayRef payload, void transform::TransformState::compactOpHandles() { for (Value handle : opHandlesToCompact) { Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); +#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS + if (llvm::find(mappings.direct[handle], nullptr) != + mappings.direct[handle].end()) + // Payload IR is removed from the mapping. This invalidates the respective + // iterators. + mappings.incrementTimestamp(handle); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS llvm::erase_value(mappings.direct[handle], nullptr); } opHandlesToCompact.clear();