Skip to content

Commit 5a600c2

Browse files
authored
[mlir][python] Expose PyInsertionPoint's reference operation (#69082)
The reason I want this is that I am writing my own Python bindings and would like to use the insertion point from `PyThreadContextEntry::getDefaultInsertionPoint()` to call C++ functions that take an `OpBuilder` (I don't need to expose it in Python but it also seems appropriate). AFAICT, there is currently no way to translate a `PyInsertionPoint` into an `OpBuilder` because the operation is inaccessible.
1 parent a3a0f59 commit 5a600c2

File tree

4 files changed

+23
-1
lines changed

4 files changed

+23
-1
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3207,7 +3207,18 @@ void mlir::python::populateIRCore(py::module &m) {
32073207
"Inserts an operation.")
32083208
.def_property_readonly(
32093209
"block", [](PyInsertionPoint &self) { return self.getBlock(); },
3210-
"Returns the block that this InsertionPoint points to.");
3210+
"Returns the block that this InsertionPoint points to.")
3211+
.def_property_readonly(
3212+
"ref_operation",
3213+
[](PyInsertionPoint &self) -> py::object {
3214+
auto ref_operation = self.getRefOperation();
3215+
if (ref_operation)
3216+
return ref_operation->getObject();
3217+
return py::none();
3218+
},
3219+
"The reference operation before which new operations are "
3220+
"inserted, or None if the insertion point is at the end of "
3221+
"the block");
32113222

32123223
//----------------------------------------------------------------------------
32133224
// Mapping of PyAttribute.

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,7 @@ class PyInsertionPoint {
833833
const pybind11::object &excTb);
834834

835835
PyBlock &getBlock() { return block; }
836+
std::optional<PyOperationRef> &getRefOperation() { return refOperation; }
836837

837838
private:
838839
// Trampoline constructor that avoids null initializing members while

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,8 @@ class InsertionPoint:
755755
def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
756756
@property
757757
def block(self) -> Block: ...
758+
@property
759+
def ref_operation(self) -> Optional[_OperationBase]: ...
758760

759761
# TODO: Auto-generated. Audit and fix.
760762
class IntegerAttr(Attribute):

mlir/test/python/ir/insertion_point.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def test_insert_at_block_end():
2727
)
2828
entry_block = module.body.operations[0].regions[0].blocks[0]
2929
ip = InsertionPoint(entry_block)
30+
assert ip.block == entry_block
31+
assert ip.ref_operation is None
3032
ip.insert(Operation.create("custom.op2"))
3133
# CHECK: "custom.op1"
3234
# CHECK: "custom.op2"
@@ -51,6 +53,8 @@ def test_insert_before_operation():
5153
)
5254
entry_block = module.body.operations[0].regions[0].blocks[0]
5355
ip = InsertionPoint(entry_block.operations[1])
56+
assert ip.block == entry_block
57+
assert ip.ref_operation == entry_block.operations[1]
5458
ip.insert(Operation.create("custom.op3"))
5559
# CHECK: "custom.op1"
5660
# CHECK: "custom.op3"
@@ -75,6 +79,8 @@ def test_insert_at_block_begin():
7579
)
7680
entry_block = module.body.operations[0].regions[0].blocks[0]
7781
ip = InsertionPoint.at_block_begin(entry_block)
82+
assert ip.block == entry_block
83+
assert ip.ref_operation == entry_block.operations[0]
7884
ip.insert(Operation.create("custom.op1"))
7985
# CHECK: "custom.op1"
8086
# CHECK: "custom.op2"
@@ -108,6 +114,8 @@ def test_insert_at_terminator():
108114
)
109115
entry_block = module.body.operations[0].regions[0].blocks[0]
110116
ip = InsertionPoint.at_block_terminator(entry_block)
117+
assert ip.block == entry_block
118+
assert ip.ref_operation == entry_block.operations[1]
111119
ip.insert(Operation.create("custom.op2"))
112120
# CHECK: "custom.op1"
113121
# CHECK: "custom.op2"

0 commit comments

Comments
 (0)