From d1f34bad1b00e310e84572296240303e9b664529 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 19 Dec 2023 21:27:48 -0800 Subject: [PATCH] [mlir][python] Make the Context/Operation capsule creation methods work as documented. --- mlir/lib/Bindings/Python/IRCore.cpp | 78 +++++++++++++++++++++--- mlir/lib/Bindings/Python/IRModule.h | 19 +++++- mlir/test/python/ir/context_lifecycle.py | 45 +++++++++++++- mlir/test/python/ir/operation.py | 13 ---- 4 files changed, 129 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5412c3dec4b1b..39757dfad5be1 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) { MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); if (mlirContextIsNull(rawContext)) throw py::error_already_set(); - return forContext(rawContext).releaseObject(); + return stealExternalContext(rawContext).releaseObject(); } PyMlirContext *PyMlirContext::createNewContextForInit() { @@ -615,18 +615,35 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { - // Create. - PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - py::object pyRef = py::cast(unownedContextWrapper); - assert(pyRef && "cast to py::object failed"); - liveContexts[context.ptr] = unownedContextWrapper; - return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); + throw std::runtime_error( + "Cannot use a context that is not owned by the Python bindings."); } + // Use existing. py::object pyRef = py::cast(it->second); return PyMlirContextRef(it->second, std::move(pyRef)); } +PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) { + py::gil_scoped_acquire acquire; + auto &liveContexts = getLiveContexts(); + auto it = liveContexts.find(context.ptr); + if (it != liveContexts.end()) { + throw std::runtime_error( + "Cannot transfer ownership of the context to Python " + "as it is already owned by Python."); + } + + PyMlirContext *unownedContextWrapper = new PyMlirContext(context); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + py::object pyRef = + py::cast(unownedContextWrapper, py::return_value_policy::take_ownership); + assert(pyRef && "cast to py::object failed"); + return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); +} + PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; @@ -1145,6 +1162,18 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, return PyOperationRef(existing, std::move(pyRef)); } +PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef, + MlirOperation operation) { + auto &liveOperations = contextRef->liveOperations; + auto it = liveOperations.find(operation.ptr); + if (it != liveOperations.end()) { + throw std::runtime_error( + "Cannot transfer ownership of the operation to Python " + "as it is already owned by Python."); + } + return createInstance(std::move(contextRef), operation, py::none()); +} + PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, py::object parentKeepAlive) { @@ -1316,7 +1345,8 @@ py::object PyOperation::createFromCapsule(py::object capsule) { if (mlirOperationIsNull(rawOperation)) throw py::error_already_set(); MlirContext rawCtxt = mlirOperationGetContext(rawOperation); - return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) + return stealExternalOperation(PyMlirContext::forContext(rawCtxt), + rawOperation) .releaseObject(); } @@ -2548,6 +2578,16 @@ void mlir::python::populateIRCore(py::module &m) { .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) + .def_static("_testing_create_raw_context_capsule", + []() { + // Creates an MlirContext not known to the Python bindings + // and puts it in a capsule. Used to test interop. Using + // this without passing it back to the capsule creation + // API will leak. + return py::reinterpret_steal( + mlirPythonContextToCapsule( + mlirContextCreateWithThreading(false))); + }) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) @@ -2973,8 +3013,7 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("binary") = false, kOperationPrintStateDocstring) .def("print", py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool>( - &PyOperationBase::print), + bool, py::object, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, @@ -3046,6 +3085,25 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) + .def_static( + "_testing_create_raw_capsule", + [](std::string sourceStr) { + // Creates a raw context and an operation via parsing the given + // source and returns them in a capsule. Error handling is + // minimal as this is purely intended for testing interop with + // operation creation from capsule functions. + MlirContext context = mlirContextCreateWithThreading(false); + MlirOperation op = mlirOperationCreateParse( + context, toMlirStringRef(sourceStr), toMlirStringRef("temp")); + if (mlirOperationIsNull(op)) { + mlirContextDestroy(context); + throw std::invalid_argument("Failed to parse"); + } + return py::make_tuple(py::reinterpret_steal( + mlirPythonContextToCapsule(context)), + py::reinterpret_steal( + mlirPythonOperationToCapsule(op))); + }) .def_property_readonly("operation", [](py::object self) { return self; }) .def_property_readonly("opview", &PyOperation::createOpView) .def_property_readonly( diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 79b7e0c96188c..04164b78b3e25 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -176,8 +176,19 @@ class PyMlirContext { static PyMlirContext *createNewContextForInit(); /// Returns a context reference for the singleton PyMlirContext wrapper for - /// the given context. + /// the given context. It is only valid to call this on an MlirContext that + /// is already owned by the Python bindings. Typically this will be because + /// it came in some fashion from createNewContextForInit(). However, it + /// is also possible to explicitly transfer ownership of an existing + /// MlirContext to the Python bindings via stealExternalContext(). static PyMlirContextRef forContext(MlirContext context); + + /// Explicitly takes ownership of an MlirContext that must not already be + /// known to the Python bindings. Once done, the life-cycle of the context + /// will be controlled by the Python bindings, and it will be destroyed + /// when the reference count goes to zero. + static PyMlirContextRef stealExternalContext(MlirContext context); + ~PyMlirContext(); /// Accesses the underlying MlirContext. @@ -606,6 +617,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject { forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); + /// Explicitly takes ownership of an operation that must not already be known + /// to the Python bindings. Once done, the life-cycle of the operation + /// will be controlled by the Python bindings. + static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef, + MlirOperation operation); + /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py index c20270999425e..fbd1851ba70ae 100644 --- a/mlir/test/python/ir/context_lifecycle.py +++ b/mlir/test/python/ir/context_lifecycle.py @@ -45,5 +45,46 @@ c4 = mlir.ir.Context() c4_capsule = c4._CAPIPtr assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule) -c5 = mlir.ir.Context._CAPICreate(c4_capsule) -assert c4 is c5 +# Because the context is already owned by Python, it cannot be created +# a second time. +try: + c5 = mlir.ir.Context._CAPICreate(c4_capsule) +except RuntimeError: + pass +else: + raise AssertionError( + "Should have gotten a RuntimeError when attempting to " + "re-create an already owned context" + ) +c4 = None +c4_capsule = None +gc.collect() +assert mlir.ir.Context._get_live_count() == 0 + +# Use a private testing method to create an unowned context capsule and +# import it. +c6_capsule = mlir.ir.Context._testing_create_raw_context_capsule() +c6 = mlir.ir.Context._CAPICreate(c6_capsule) +assert mlir.ir.Context._get_live_count() == 1 +c6_capsule = None +c6 = None +gc.collect() +assert mlir.ir.Context._get_live_count() == 0 + +# Also test operation import/export as it is tightly coupled to the context. +( + raw_context_capsule, + raw_operation_capsule, +) = mlir.ir.Operation._testing_create_raw_capsule("builtin.module {}") +assert '"mlir.ir.Operation._CAPIPtr"' in repr(raw_operation_capsule) +# Attempting to import an operation for an unknown context should fail. +try: + mlir.ir.Operation._CAPICreate(raw_operation_capsule) +except RuntimeError: + pass +else: + raise AssertionError("Expected exception for unknown context") + +# Try again having imported the context. +c7 = mlir.ir.Context._CAPICreate(raw_context_capsule) +op7 = mlir.ir.Operation._CAPICreate(raw_operation_capsule) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 04f8a9936e31f..f59b1a26ba48b 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -844,19 +844,6 @@ def testOperationName(): print(op.operation.name) -# CHECK-LABEL: TEST: testCapsuleConversions -@run -def testCapsuleConversions(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - m = Operation.create("custom.op1").operation - m_capsule = m._CAPIPtr - assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) - m2 = Operation._CAPICreate(m_capsule) - assert m2 is m - - # CHECK-LABEL: TEST: testOperationErase @run def testOperationErase():