Skip to content

[mlir][python] automatic location inference #151246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 12, 2025
39 changes: 39 additions & 0 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include <optional>
#include <regex>
#include <string>
#include <unordered_set>
#include <vector>

#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Regex.h"

namespace mlir {
namespace python {
Expand Down Expand Up @@ -114,6 +118,39 @@ class PyGlobals {
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);

/// Configuration and filename filtering used when inferring operation
/// locations from the Python call stack. The member functions are defined out
/// of line; each of them takes `mutex` before touching any state.
class TracebackLoc {
public:
/// Returns whether traceback-based location inference is enabled.
bool locTracebacksEnabled();

/// Enables or disables traceback-based location inference.
void setLocTracebacksEnabled(bool value);

/// Returns the maximum number of Python frames recorded per location.
size_t locTracebackFramesLimit();

/// Sets the maximum number of recorded frames; the stored value is clamped to
/// `kMaxFrames`.
void setLocTracebackFramesLimit(size_t value);

/// Registers `file` as a filename prefix to include. The prefix is regex
/// escaped and anchored with "^"; an identical exclusion entry is removed.
void registerTracebackFileInclusion(const std::string &file);

/// Registers `file` as a filename prefix to exclude. The prefix is regex
/// escaped and anchored with "^"; an identical inclusion entry is removed.
void registerTracebackFileExclusion(const std::string &file);

/// Returns true if `file` matches the inclusion patterns or does not match
/// the exclusion patterns. Results are memoized per filename.
bool isUserTracebackFilename(llvm::StringRef file);

/// Upper bound for the frames limit; also the capacity of the per-thread
/// frame buffer used when building a traceback location.
static constexpr size_t kMaxFrames = 512;

private:
// Guards all of the state below.
nanobind::ft_mutex mutex;
// Traceback-based location inference is off by default.
bool locTracebackEnabled_ = false;
// Default number of frames recorded per location.
size_t locTracebackFramesLimit_ = 10;
// Anchored, escaped patterns registered for inclusion / exclusion.
std::unordered_set<std::string> userTracebackIncludeFiles;
std::unordered_set<std::string> userTracebackExcludeFiles;
// Alternation of the include patterns, rebuilt lazily by
// isUserTracebackFilename when the corresponding flag is set.
std::regex userTracebackIncludeRegex;
bool rebuildUserTracebackIncludeRegex = false;
// Alternation of the exclude patterns, rebuilt lazily in the same way.
std::regex userTracebackExcludeRegex;
bool rebuildUserTracebackExcludeRegex = false;
// Memoized isUserTracebackFilename results; cleared whenever either regex is
// rebuilt.
llvm::StringMap<bool> isUserTracebackFilenameCache;
};

/// Returns a mutable reference to the `tracebackLoc` member of this PyGlobals.
TracebackLoc &getTracebackLoc() { return tracebackLoc; }

private:
static PyGlobals *instance;

Expand All @@ -134,6 +171,8 @@ class PyGlobals {
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;

TracebackLoc tracebackLoc;
};

} // namespace python
Expand Down
122 changes: 104 additions & 18 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>
#include <system_error>
#include <utility>

namespace nb = nanobind;
using namespace nb::literals;
Expand Down Expand Up @@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
llvm::ArrayRef<MlirValue> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation location,
int regions, PyLocation &location,
const nb::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
Expand Down Expand Up @@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
if (!operation.ptr)
throw nb::value_error("Operation creation failed");
PyOperationRef created =
PyOperation::createDetached(location->getContext(), operation);
PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);

return created.getObject();
Expand Down Expand Up @@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
std::optional<int> regions, PyLocation &location,
const nb::object &maybeIp) {
PyMlirContextRef context = location->getContext();
PyMlirContextRef context = location.getContext();

// Class level operation construction metadata.
// Operand and result segment specs are either none, which does no
Expand Down Expand Up @@ -2789,6 +2786,90 @@ class PyOpAttributeMap {
PyOperationRef operation;
};

/// Builds an MLIR location describing the current Python call stack.
///
/// Walks the calling thread's Python frames starting at the innermost one and
/// keeps only frames whose filename is accepted by
/// `PyGlobals::TracebackLoc::isUserTracebackFilename`, stopping once the
/// configured frame limit has been collected. Every kept frame becomes a name
/// location (the function name) wrapping a file location.
///
/// Returns:
///   * an unknown location if no frame was kept;
///   * the single frame location if exactly one frame was kept;
///   * otherwise a call-site chain whose callee is the innermost kept frame and
///     whose callers are the remaining kept frames, outermost last.
///
/// Throws `nb::python_error` if `PyCode_Addr2Location` reports failure
/// (Python >= 3.11 only).
MlirLocation tracebackToLocation(MlirContext ctx) {
  size_t framesLimit =
      PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
  // Use a thread_local here to avoid requiring a large amount of space.
  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
      frames;
  size_t count = 0;

  nb::gil_scoped_acquire acquire;
  PyThreadState *tstate = PyThreadState_GET();
  PyFrameObject *next;
  PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
  // In the increment expression:
  //   1. get the next prev frame;
  //   2. decrement the ref count on the current frame (in order that it can get
  //      gc'd, along with any objects in its closure and etc);
  //   3. set current = next.
  for (; pyFrame != nullptr && count < framesLimit;
       next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
    // PyFrame_GetCode returns a *strong* reference. Hand it to an owning
    // nb::object so it is released on every way out of this iteration
    // (`continue`, normal fallthrough, or an exception).
    PyCodeObject *code = PyFrame_GetCode(pyFrame);
    nb::object codeOwner =
        nb::steal<nb::object>(reinterpret_cast<PyObject *>(code));

    auto fileNameStr =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
    llvm::StringRef fileName(fileNameStr);
    if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
      continue;

#if PY_VERSION_HEX < 0x030b00f0
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
    llvm::StringRef funcName(name);
    int startLine = PyFrame_GetLineNumber(pyFrame);
    MlirLocation loc =
        mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
#else
    // co_qualname and PyCode_Addr2Location added in py3.11
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
    llvm::StringRef funcName(name);
    int startLine, startCol, endLine, endCol;
    int lasti = PyFrame_GetLasti(pyFrame);
    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
                              &endCol)) {
      // Throwing skips both the loop increment and the cleanup after the loop,
      // so the current frame has to be released here.
      Py_DECREF(pyFrame);
      // NOTE(review): confirm PyCode_Addr2Location sets a Python error on
      // failure; nb::python_error() expects the error indicator to be set.
      throw nb::python_error();
    }
    MlirLocation loc = mlirLocationFileLineColRangeGet(
        ctx, wrap(fileName), startLine, startCol, endLine, endCol);
#endif

    frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
    ++count;
  }
  // When the loop breaks (after the last iter), current frame (if non-null)
  // is leaked without this.
  Py_XDECREF(pyFrame);

  if (count == 0)
    return mlirLocationUnknownGet(ctx);

  MlirLocation callee = frames[0];
  assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
  if (count == 1)
    return callee;

  // Fold the outer frames into nested call sites, starting from the outermost
  // kept frame and working inwards; frames[0] is attached last as the callee.
  MlirLocation caller = frames[count - 1];
  assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
  for (int i = static_cast<int>(count) - 2; i >= 1; i--)
    caller = mlirLocationCallSiteGet(frames[i], caller);

  return mlirLocationCallSiteGet(callee, caller);
}

/// Picks the location to attach to a newly created IR object.
///
/// An explicitly supplied `location` always wins. Otherwise, when traceback
/// locations are disabled, the default location is resolved through
/// `DefaultingPyLocation::resolve()`. When they are enabled, a location is
/// synthesized from the Python call stack in the default context.
PyLocation
maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
  if (location)
    return *location;

  const bool useTraceback =
      PyGlobals::get().getTracebackLoc().locTracebacksEnabled();
  if (!useTraceback)
    return DefaultingPyLocation::resolve();

  PyMlirContext &defaultCtx = DefaultingPyMlirContext::resolve();
  MlirLocation inferredLoc = tracebackToLocation(defaultCtx.get());
  PyMlirContextRef ctxRef = PyMlirContext::forContext(defaultCtx.get());
  return {ctxRef, inferredLoc};
}

} // namespace

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3052,10 +3133,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
.def_prop_ro_static(
"current",
[](nb::object & /*class*/) {
[](nb::object & /*class*/) -> std::optional<PyLocation *> {
auto *loc = PyThreadContextEntry::getDefaultLocation();
if (!loc)
throw nb::value_error("No current Location");
return std::nullopt;
Comment on lines +3136 to +3139
Copy link
Contributor Author

@makslevental makslevental Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change Location.current to return None instead of throwing - see above

return loc;
},
"Gets the Location bound to the current thread or raises ValueError")
Expand Down Expand Up @@ -3240,8 +3321,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
[](DefaultingPyLocation loc) {
MlirModule module = mlirModuleCreateEmpty(loc);
[](const std::optional<PyLocation> &loc) {
PyLocation pyLoc = maybeGetTracebackLocation(loc);
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
nb::arg("loc").none() = nb::none(), "Creates an empty module")
Expand Down Expand Up @@ -3454,8 +3536,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<std::vector<PyValue *>> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation location, const nb::object &maybeIp,
bool inferType) {
const std::optional<PyLocation> &location,
const nb::object &maybeIp, bool inferType) {
// Unpack/validate operands.
llvm::SmallVector<MlirValue, 4> mlirOperands;
if (operands) {
Expand All @@ -3467,8 +3549,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
}

PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOperation::create(name, results, mlirOperands, attributes,
successors, regions, location, maybeIp,
successors, regions, pyLoc, maybeIp,
inferType);
},
nb::arg("name"), nb::arg("results").none() = nb::none(),
Expand Down Expand Up @@ -3512,12 +3595,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
std::optional<int> regions,
const std::optional<PyLocation> &location,
const nb::object &maybeIp) {
PyLocation pyLoc = maybeGetTracebackLocation(location);
new (self) PyOpView(PyOpView::buildGeneric(
name, opRegionSpec, operandSegmentSpecObj,
resultSegmentSpecObj, resultTypeList, operandList,
attributes, successors, regions, location, maybeIp));
attributes, successors, regions, pyLoc, maybeIp));
},
nb::arg("name"), nb::arg("opRegionSpec"),
nb::arg("operandSegmentSpecObj").none() = nb::none(),
Expand Down Expand Up @@ -3551,17 +3636,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](nb::handle cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
std::optional<int> regions, std::optional<PyLocation> location,
const nb::object &maybeIp) {
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
std::tuple<int, bool> opRegionSpec =
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
resultSegmentSpec, resultTypeList,
operandList, attributes, successors,
regions, location, maybeIp);
regions, pyLoc, maybeIp);
},
nb::arg("cls"), nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
Expand Down
70 changes: 69 additions & 1 deletion mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

#include "Globals.h"
#include "NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.

namespace nb = nanobind;
using namespace mlir;
Expand Down Expand Up @@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Not found and loading did not yield a registration.
return std::nullopt;
}

/// Reports whether traceback-based location inference is switched on. The flag
/// is read while holding the TracebackLoc mutex.
bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
  nanobind::ft_lock_guard guard(mutex);
  const bool enabled = locTracebackEnabled_;
  return enabled;
}

void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
nanobind::ft_lock_guard lock(mutex);
locTracebackEnabled_ = value;
}

/// Returns the maximum number of Python frames recorded per inferred location.
/// The value is read while holding the TracebackLoc mutex.
size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
  nanobind::ft_lock_guard guard(mutex);
  const size_t limit = locTracebackFramesLimit_;
  return limit;
}

/// Sets the maximum number of Python frames recorded per inferred location.
/// Requests larger than `kMaxFrames` are clamped down to `kMaxFrames`.
void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
  nanobind::ft_lock_guard guard(mutex);
  const size_t clamped = value < kMaxFrames ? value : kMaxFrames;
  locTracebackFramesLimit_ = clamped;
}

/// Marks filenames starting with `file` as user code to keep in tracebacks.
///
/// The prefix is regex-escaped and anchored with "^" before being stored. If
/// the same pattern was registered as an exclusion it is dropped from that set.
/// Either change flags the corresponding combined regex for a lazy rebuild.
void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
    const std::string &file) {
  nanobind::ft_lock_guard guard(mutex);
  const std::string pattern = "^" + llvm::Regex::escape(file);

  const bool newlyIncluded = userTracebackIncludeFiles.insert(pattern).second;
  if (newlyIncluded)
    rebuildUserTracebackIncludeRegex = true;

  // erase() returns the number of removed elements, so a non-zero result means
  // the pattern had previously been registered as an exclusion.
  const bool wasExcluded = userTracebackExcludeFiles.erase(pattern) != 0;
  if (wasExcluded)
    rebuildUserTracebackExcludeRegex = true;
}

/// Marks filenames starting with `file` as non-user code to drop from
/// tracebacks.
///
/// The prefix is regex-escaped and anchored with "^" before being stored. If
/// the same pattern was registered as an inclusion it is dropped from that set.
/// Either change flags the corresponding combined regex for a lazy rebuild.
void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
    const std::string &file) {
  nanobind::ft_lock_guard guard(mutex);
  const std::string pattern = "^" + llvm::Regex::escape(file);

  const bool newlyExcluded = userTracebackExcludeFiles.insert(pattern).second;
  if (newlyExcluded)
    rebuildUserTracebackExcludeRegex = true;

  // erase() returns the number of removed elements, so a non-zero result means
  // the pattern had previously been registered as an inclusion.
  const bool wasIncluded = userTracebackIncludeFiles.erase(pattern) != 0;
  if (wasIncluded)
    rebuildUserTracebackIncludeRegex = true;
}

/// Decides whether `file` counts as user code for traceback locations.
///
/// A filename is user code if it matches any registered inclusion pattern or
/// if it matches no registered exclusion pattern. The combined include/exclude
/// regexes are rebuilt lazily here when a registration has changed the
/// pattern sets, and the per-filename results are memoized until the next
/// rebuild.
bool PyGlobals::TracebackLoc::isUserTracebackFilename(
    const llvm::StringRef file) {
  nanobind::ft_lock_guard lock(mutex);
  if (rebuildUserTracebackIncludeRegex) {
    // An empty set must match nothing. Joining an empty set produces the empty
    // pattern, which matches *every* string, so fall back to a
    // default-constructed regex (which matches nothing) in that case.
    if (userTracebackIncludeFiles.empty())
      userTracebackIncludeRegex = std::regex();
    else
      userTracebackIncludeRegex.assign(
          llvm::join(userTracebackIncludeFiles, "|"));
    rebuildUserTracebackIncludeRegex = false;
    isUserTracebackFilenameCache.clear();
  }
  if (rebuildUserTracebackExcludeRegex) {
    // Same reasoning as above: no exclusions means nothing is excluded.
    if (userTracebackExcludeFiles.empty())
      userTracebackExcludeRegex = std::regex();
    else
      userTracebackExcludeRegex.assign(
          llvm::join(userTracebackExcludeFiles, "|"));
    rebuildUserTracebackExcludeRegex = false;
    isUserTracebackFilenameCache.clear();
  }

  // Serve from the cache with a single lookup when possible.
  auto cached = isUserTracebackFilenameCache.find(file);
  if (cached != isUserTracebackFilenameCache.end())
    return cached->getValue();

  std::string fileStr = file.str();
  bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
  bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
  // Inclusion takes precedence over exclusion.
  bool isUser = include || !exclude;
  isUserTracebackFilenameCache[file] = isUser;
  return isUser;
}
Loading