Skip to content

Conversation

vfdev-5
Copy link
Contributor

@vfdev-5 vfdev-5 commented Mar 6, 2025

Description:

  • Exposed MlirLlvmThreadPool as PyThreadPool in MLIR Python bindings
  • Add tests

Context:

In JAX ir.Context are used with disabled multi-threading to avoid caching multiple threading pools:
https://github.com/jax-ml/jax/blob/623865fe9538100d877ba9d36f788d0f95a11ed2/jax/_src/interpreters/mlir.py#L606-L611
However, when context has enabled multithreading it also uses locks on the StorageUniquers and this can be helpful to avoid data races in the multi-threaded execution (for example with free-threaded cpython, jax-ml/jax#26272).
With this PR user can enable the multi-threading: 1) enables additional locking and 2) set a shared threading pool such that cached contexts can have one global pool.

cc @hawkinsp

@llvmbot llvmbot added the mlir label Mar 6, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-mlir

Author: vfdev (vfdev-5)

Changes

cc @hawkinsp


Full diff: https://github.com/llvm/llvm-project/pull/130109.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+10)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+19-1)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..1ec52a1a9bcd4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   // __init__.py will subclass it with site-specific functionality and set a
   // "Context" attribute on this module.
   //----------------------------------------------------------------------------
+
+  // Expose DefaultThreadPool to python
+  nb::class_<PyThreadPool>(m, "ThreadPool")
+      .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
+      .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+
   nb::class_<PyMlirContext>(m, "_BaseContext")
       .def("__init__",
            [](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             mlirContextEnableMultithreading(self.get(), enable);
           },
           nb::arg("enable"))
+      .def("set_thread_pool",
+           [](PyMlirContext &self, PyThreadPool &pool) {
+             mlirContextSetThreadPool(self.get(), pool.get());
+           })
       .def(
           "is_registered_operation",
           [](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..b7bbd646d982e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,9 +22,10 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
 
 namespace mlir {
 namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
   FrameKind frameKind;
 };
 
+/// Wrapper around MlirLlvmThreadPool
+/// Python object owns the C++ thread pool
+class PyThreadPool {
+public:
+  PyThreadPool() {
+    ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+  }
+  PyThreadPool(const PyThreadPool &) = delete;
+  PyThreadPool(PyThreadPool &&) = delete;
+
+  int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
+  MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
+
+private:
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
+};
+
 /// Wrapper around MlirContext.
 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
 class PyMlirContext {

@vfdev-5 vfdev-5 marked this pull request as draft March 6, 2025 13:52
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from cb796b9 to feedd57 Compare March 6, 2025 14:50
@vfdev-5 vfdev-5 marked this pull request as ready for review March 6, 2025 14:50
@llvmbot llvmbot added llvm:support bazel "Peripheral" support tier build system: utils/bazel labels Mar 6, 2025
@vfdev-5 vfdev-5 marked this pull request as draft March 6, 2025 14:56
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from feedd57 to 6e0960a Compare March 6, 2025 14:58
@vfdev-5 vfdev-5 marked this pull request as ready for review March 6, 2025 15:03
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from 6e0960a to 998821b Compare March 7, 2025 10:47
@llvmbot llvmbot added the mlir:python MLIR Python bindings label Mar 7, 2025
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from 998821b to cb63466 Compare March 7, 2025 10:49
@vfdev-5
Copy link
Contributor Author

vfdev-5 commented Mar 7, 2025

@joker-eph in the last commit (7db1c50) I added thread_pool arg to mlir.ir.Context constructor. Also exposed get_num_threads and _mlir_thread_pool_ptr methods:

  • updated MLIR C API: added mlirContextGetNumThreads, mlirContextGetThreadPool methods

_mlir_thread_pool_ptr is an internal method to ensure that C++ thread pool is correctly set up.

Let me know if this works. Thanks!

Copy link

github-actions bot commented Mar 7, 2025

✅ With the latest revision this PR passed the Python code formatter.

@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from cb63466 to 7db1c50 Compare March 7, 2025 10:56
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG with one minor thing to check.

@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from 7db1c50 to a557554 Compare March 8, 2025 13:37
@joker-eph joker-eph merged commit ab18cc2 into llvm:main Mar 10, 2025
11 checks passed
@vfdev-5 vfdev-5 deleted the mlir-python-expose-threadpool branch March 10, 2025 10:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel llvm:support mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants