diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d562da1f90757..001660ee51311 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool); +/// Gets the number of threads of the thread pool of the context when +/// multithreading is enabled. Returns 1 if no multithreading. +MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context); + +/// Gets the thread pool of the context when enabled multithreading, otherwise +/// an assertion is raised. +MLIR_CAPI_EXPORTED MlirLlvmThreadPool +mlirContextGetThreadPool(MlirContext context); + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 12793f7dd15be..22d6d117573b9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2743,6 +2743,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { // __init__.py will subclass it with site-specific functionality and set a // "Context" attribute on this module. //---------------------------------------------------------------------------- + + // Expose DefaultThreadPool to python + nb::class_(m, "ThreadPool") + .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) + .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) + .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); + nb::class_(m, "_BaseContext") .def("__init__", [](PyMlirContext &self) { @@ -2814,6 +2821,25 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirContextEnableMultithreading(self.get(), enable); }, nb::arg("enable")) + .def("set_thread_pool", + [](PyMlirContext &self, PyThreadPool &pool) { + // we should disable multi-threading first before setting + // new thread pool otherwise the assert in + // MLIRContext::setThreadPool will be raised. + mlirContextEnableMultithreading(self.get(), false); + mlirContextSetThreadPool(self.get(), pool.get()); + }) + .def("get_num_threads", + [](PyMlirContext &self) { + return mlirContextGetNumThreads(self.get()); + }) + .def("_mlir_thread_pool_ptr", + [](PyMlirContext &self) { + MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); + std::stringstream ss; + ss << pool.ptr; + return ss.str(); + }) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 1ed6240a6ca69..9befcce725bb7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -11,6 +11,7 @@ #define MLIR_BINDINGS_PYTHON_IRMODULES_H #include +#include #include #include @@ -22,9 +23,10 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ThreadPool.h" namespace mlir { namespace python { @@ -158,6 +160,29 @@ class PyThreadContextEntry { FrameKind frameKind; }; +/// Wrapper around MlirLlvmThreadPool +/// Python object owns the C++ thread pool +class PyThreadPool { +public: + PyThreadPool() { + ownedThreadPool = std::make_unique(); + } + PyThreadPool(const PyThreadPool &) = delete; + PyThreadPool(PyThreadPool &&) = delete; + + int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } + MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } + + std::string _mlir_thread_pool_ptr() const { + std::stringstream ss; + ss << ownedThreadPool.get(); + return ss.str(); + } + +private: + std::unique_ptr ownedThreadPool; +}; + /// Wrapper around MlirContext. using PyMlirContextRef = PyObjectRef; class PyMlirContext { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6cd9ba2aef233..649f3b7056fb0 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context, unwrap(context)->setThreadPool(*unwrap(threadPool)); } +unsigned mlirContextGetNumThreads(MlirContext context) { + return unwrap(context)->getNumThreads(); +} + +MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) { + return wrap(&unwrap(context)->getThreadPool()); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index d021dde05dd87..083a9075fe4c5 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -148,13 +148,25 @@ def process_initializer_module(module_name): break class Context(ir._BaseContext): - def __init__(self, load_on_create_dialects=None, *args, **kwargs): + def __init__( + self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs + ): super().__init__(*args, **kwargs) self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) + if disable_multithreading and thread_pool is not None: + raise ValueError( + "Context constructor has given thread_pool argument, " + "but disable_multithreading flag is True. " + "Please, set thread_pool argument to None or " + "set disable_multithreading flag to False." + ) if not disable_multithreading: - self.enable_multithreading(True) + if thread_pool is None: + self.enable_multithreading(True) + else: + self.set_thread_pool(thread_pool) if load_on_create_dialects is not None: logger.debug( "Loading all dialects from load_on_create_dialects arg %r", diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py index c20270999425e..230db8277c8e7 100644 --- a/mlir/test/python/ir/context_lifecycle.py +++ b/mlir/test/python/ir/context_lifecycle.py @@ -47,3 +47,26 @@ assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule) c5 = mlir.ir.Context._CAPICreate(c4_capsule) assert c4 is c5 +c4 = None +c5 = None +gc.collect() + +# Create a global threadpool and use it in two contexts +tp = mlir.ir.ThreadPool() +assert tp.get_max_concurrency() > 0 +c5 = mlir.ir.Context() +c5.set_thread_pool(tp) +assert c5.get_num_threads() == tp.get_max_concurrency() +assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr() +c6 = mlir.ir.Context() +c6.set_thread_pool(tp) +assert c6.get_num_threads() == tp.get_max_concurrency() +assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr() +c7 = mlir.ir.Context(thread_pool=tp) +assert c7.get_num_threads() == tp.get_max_concurrency() +assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr() +assert mlir.ir.Context._get_live_count() == 3 +c5 = None +c6 = None +c7 = None +gc.collect()