diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index e3e3e86231465..2d60036716611 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -146,6 +146,16 @@ class IROperand : public detail::IROperandBase { return *this; } + /// Two operands are equal if they have the same owner and the same operand + /// number. They are stored inside of ops, so it is valid to compare their + /// pointers to determine equality. + bool operator==(const IROperand &other) const { + return this == &other; + } + bool operator!=(const IROperand &other) const { + return !(*this == other); + } + /// Return the current value being used by this operand. IRValueT get() const { return value; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 5716dcc9d9050..52ff6ceeee85b 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -537,12 +537,12 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, bool MaterializeInDestinationOp::bufferizesToMemoryRead( OpOperand &opOperand, const AnalysisState &state) { - return &opOperand == &getSourceMutable(); + return opOperand == getSourceMutable(); } bool MaterializeInDestinationOp::bufferizesToMemoryWrite( OpOperand &opOperand, const AnalysisState &state) { - if (&opOperand == &getDestMutable()) { + if (opOperand == getDestMutable()) { assert(isa(getDest().getType()) && "expected tensor type"); return true; } @@ -560,7 +560,7 @@ bool MaterializeInDestinationOp::mustBufferizeInPlace( AliasingValueList MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand, const AnalysisState &state) { - if (&opOperand == &getDestMutable()) { + if (opOperand == getDestMutable()) { assert(isa(getDest().getType()) && "expected tensor type"); return {{getOperation()->getResult(0), BufferRelation::Equivalent}}; } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 9386d0fd0f04f..a95443db88b50 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -644,11 +644,11 @@ struct InsertSliceOpInterface RankedTensorType destType = insertSliceOp.getDestType(); // The source is always read. - if (&opOperand == &insertSliceOp.getSourceMutable()) + if (opOperand == insertSliceOp.getSourceMutable()) return true; // For the destination, it depends... - assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest"); + assert(opOperand == insertSliceOp.getDestMutable() && "expected dest"); // Dest is not read if it is entirely overwritten. E.g.: // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> @@ -849,7 +849,7 @@ struct ReshapeOpInterface const AnalysisState &state) const { // Depending on the layout map, the source buffer may have to be copied. auto reshapeOp = cast(op); - return &opOperand == &reshapeOp.getShapeMutable(); + return opOperand == reshapeOp.getShapeMutable(); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, @@ -931,7 +931,7 @@ struct ParallelInsertSliceOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto parallelInsertSliceOp = cast(op); - return &opOperand == ¶llelInsertSliceOp.getDestMutable(); + return opOperand == parallelInsertSliceOp.getDestMutable(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter,