Skip to content

Commit 3ab9de1

Browse files
committed
fix live contexts
1 parent 6e5cd75 commit 3ab9de1

File tree

3 files changed

+24
-38
lines changed

3 files changed

+24
-38
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
15201520
llvm::ArrayRef<MlirValue> operands,
15211521
std::optional<nb::dict> attributes,
15221522
std::optional<std::vector<PyBlock *>> successors,
1523-
int regions, PyLocation location,
1523+
int regions, PyLocation &location,
15241524
const nb::object &maybeIp, bool inferType) {
15251525
llvm::SmallVector<MlirType, 4> mlirResults;
15261526
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1934,7 +1934,7 @@ nb::object PyOpView::buildGeneric(
19341934
std::optional<nb::list> resultTypeList, nb::list operandList,
19351935
std::optional<nb::dict> attributes,
19361936
std::optional<std::vector<PyBlock *>> successors,
1937-
std::optional<int> regions, PyLocation location,
1937+
std::optional<int> regions, PyLocation &location,
19381938
const nb::object &maybeIp) {
19391939
PyMlirContextRef context = location.getContext();
19401940

@@ -2795,13 +2795,12 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
27952795
thread_local std::array<MlirLocation, kMaxFrames> frames;
27962796
size_t count = 0;
27972797

2798-
assert(PyGILState_Check());
2799-
2798+
nb::gil_scoped_acquire acquire;
28002799
PyThreadState *tstate = PyThreadState_GET();
28012800

28022801
PyFrameObject *next;
28032802
for (PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
2804-
pyFrame != nullptr && count < framesLimit;
2803+
pyFrame != nullptr && count < kMaxFrames;
28052804
next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
28062805
PyCodeObject *code = PyFrame_GetCode(pyFrame);
28072806
auto fileNameStr =
@@ -2834,8 +2833,10 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
28342833

28352834
frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
28362835
++count;
2837-
if (count > framesLimit)
2836+
if (count > framesLimit) {
2837+
Py_XDECREF(pyFrame);
28382838
break;
2839+
}
28392840
}
28402841

28412842
if (count == 0)
@@ -2856,22 +2857,15 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
28562857

28572858
PyLocation
28582859
maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
2859-
MlirLocation mlirLoc;
2860-
MlirContext mlirCtx;
2861-
if (!location.has_value() &&
2862-
PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) {
2863-
mlirCtx = DefaultingPyMlirContext::resolve().get();
2864-
mlirLoc = tracebackToLocation(mlirCtx);
2865-
} else if (!location.has_value()) {
2866-
mlirLoc = DefaultingPyLocation::resolve();
2867-
mlirCtx = mlirLocationGetContext(mlirLoc);
2868-
} else {
2869-
mlirLoc = *location;
2870-
mlirCtx = mlirLocationGetContext(mlirLoc);
2871-
}
2872-
assert(!mlirLocationIsNull(mlirLoc) && "expected non-null mlirLoc");
2873-
PyMlirContextRef ctx = PyMlirContext::forContext(mlirCtx);
2874-
return {ctx, mlirLoc};
2860+
if (location.has_value())
2861+
return location.value();
2862+
if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
2863+
return DefaultingPyLocation::resolve();
2864+
2865+
PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
2866+
MlirLocation mlirLoc = tracebackToLocation(ctx.get());
2867+
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
2868+
return {ref, mlirLoc};
28752869
}
28762870

28772871
} // namespace
@@ -3325,7 +3319,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
33253319
kModuleParseDocstring)
33263320
.def_static(
33273321
"create",
3328-
[](std::optional<PyLocation> loc) {
3322+
[](const std::optional<PyLocation> &loc) {
33293323
PyLocation pyLoc = maybeGetTracebackLocation(loc);
33303324
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
33313325
return PyModule::forModule(module).releaseObject();
@@ -3540,8 +3534,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35403534
std::optional<std::vector<PyValue *>> operands,
35413535
std::optional<nb::dict> attributes,
35423536
std::optional<std::vector<PyBlock *>> successors, int regions,
3543-
std::optional<PyLocation> location, const nb::object &maybeIp,
3544-
bool inferType) {
3537+
const std::optional<PyLocation> &location,
3538+
const nb::object &maybeIp, bool inferType) {
35453539
// Unpack/validate operands.
35463540
llvm::SmallVector<MlirValue, 4> mlirOperands;
35473541
if (operands) {
@@ -3599,7 +3593,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35993593
std::optional<nb::list> resultTypeList, nb::list operandList,
36003594
std::optional<nb::dict> attributes,
36013595
std::optional<std::vector<PyBlock *>> successors,
3602-
std::optional<int> regions, std::optional<PyLocation> location,
3596+
std::optional<int> regions,
3597+
const std::optional<PyLocation> &location,
36033598
const nb::object &maybeIp) {
36043599
PyLocation pyLoc = maybeGetTracebackLocation(location);
36053600
new (self) PyOpView(PyOpView::buildGeneric(

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -192,16 +192,6 @@ class PyMlirContext {
192192
PyMlirContext(const PyMlirContext &) = delete;
193193
PyMlirContext(PyMlirContext &&) = delete;
194194

195-
/// For the case of a python __init__ (nanobind::init) method, pybind11 is
196-
/// quite strict about needing to return a pointer that is not yet associated
197-
/// to an nanobind::object. Since the forContext() method acts like a pool,
198-
/// possibly returning a recycled context, it does not satisfy this need. The
199-
/// usual way in python to accomplish such a thing is to override __new__, but
200-
/// that is also not supported by pybind11. Instead, we use this entry
201-
/// point which always constructs a fresh context (which cannot alias an
202-
/// existing one because it is fresh).
203-
static PyMlirContext *createNewContextForInit();
204-
205195
/// Returns a context reference for the singleton PyMlirContext wrapper for
206196
/// the given context.
207197
static PyMlirContextRef forContext(MlirContext context);
@@ -722,7 +712,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
722712
llvm::ArrayRef<MlirValue> operands,
723713
std::optional<nanobind::dict> attributes,
724714
std::optional<std::vector<PyBlock *>> successors, int regions,
725-
PyLocation location, const nanobind::object &ip, bool inferType);
715+
PyLocation &location, const nanobind::object &ip, bool inferType);
726716

727717
/// Creates an OpView suitable for this operation.
728718
nanobind::object createOpView();
@@ -780,7 +770,7 @@ class PyOpView : public PyOperationBase {
780770
nanobind::list operandList,
781771
std::optional<nanobind::dict> attributes,
782772
std::optional<std::vector<PyBlock *>> successors,
783-
std::optional<int> regions, PyLocation location,
773+
std::optional<int> regions, PyLocation &location,
784774
const nanobind::object &maybeIp);
785775

786776
/// Construct an instance of a class deriving from OpView, bypassing its

mlir/test/python/ir/auto_location.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def run(f):
1313
f()
1414
gc.collect()
1515
assert Context._get_live_count() == 0
16+
return f
1617

1718

1819
@contextmanager

0 commit comments

Comments
 (0)