Skip to content

Commit 7c85086

Browse files
authored
[mlir][python] value casting (#69644)
This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a proxy class that overloads dunders such as `__add__`, `__sub__`, and `__mul__` for fun and great profit. This is thematically similar to bfb1ba7 and 9566ee2. The example in the test demonstrates the value of the feature (no pun intended): ```python @register_value_caster(F16Type.static_typeid) @register_value_caster(F32Type.static_typeid) @register_value_caster(F64Type.static_typeid) @register_value_caster(IntegerType.static_typeid) class ArithValue(Value): __add__ = partialmethod(_binary_op, op="add") __sub__ = partialmethod(_binary_op, op="sub") __mul__ = partialmethod(_binary_op, op="mul") a = arith.constant(value=FloatAttr.get(f16_t, 42.42)) b = a + a # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16) print(b) a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) b = a - a # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32) print(b) a = arith.constant(value=FloatAttr.get(f64_t, 42.42)) b = a * a # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64) print(b) ``` **EDIT**: this now goes through the bindings and thus supports automatic casting of `OpResult` (including as an element of `OpResultList`), `BlockArgument` (including as an element of `BlockArgumentList`), as well as `Value`.
1 parent 867ece1 commit 7c85086

File tree

16 files changed

+371
-58
lines changed

16 files changed

+371
-58
lines changed

mlir/include/mlir-c/Bindings/Python/Interop.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,28 @@
118118

119119
/** Attribute on main C extension module (_mlir) that corresponds to the
120120
* type caster registration binding. The signature of the function is:
121-
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
122-
* bool replace)
123-
* where replace indicates the typeCaster should replace any existing registered
124-
* type casters (such as those for upstream ConcreteTypes).
121+
* def register_type_caster(MlirTypeID mlirTypeID, *, bool replace)
122+
* which then takes a typeCaster (register_type_caster is meant to be used as a
123+
* decorator from python), and where replace indicates the typeCaster should
124+
* replace any existing registered type casters (such as those for upstream
125+
* ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type)
126+
* -> SubClassTypeT where SubClassTypeT indicates the result should be a
127+
* subclass (inherit from) ir.Type.
125128
*/
126129
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
127130

131+
/** Attribute on main C extension module (_mlir) that corresponds to the
132+
* value caster registration binding. The signature of the function is:
133+
* def register_value_caster(MlirTypeID mlirTypeID, *, bool replace)
134+
* which then takes a valueCaster (register_value_caster is meant to be used as
135+
* a decorator, from python), and where replace indicates the valueCaster should
136+
* replace any existing registered value casters. The interface of the
137+
* valueCaster is: def value_caster(ir.Value) -> SubClassValueT where
138+
* SubClassValueT indicates the result should be a subclass (inherit from)
139+
* ir.Value.
140+
*/
141+
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"
142+
128143
/// Gets a void* from a wrapped struct. Needed because const cast is different
129144
/// between C/C++.
130145
#ifdef __cplusplus

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ struct type_caster<MlirValue> {
234234
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
235235
.attr("Value")
236236
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
237+
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
237238
.release();
238239
};
239240
};
@@ -496,11 +497,10 @@ class mlir_type_subclass : public pure_subclass {
496497
if (getTypeIDFunction) {
497498
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
498499
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
499-
getTypeIDFunction(),
500-
pybind11::cpp_function(
501-
[thisClass = thisClass](const py::object &mlirType) {
502-
return thisClass(mlirType);
503-
}));
500+
getTypeIDFunction())(pybind11::cpp_function(
501+
[thisClass = thisClass](const py::object &mlirType) {
502+
return thisClass(mlirType);
503+
}));
504504
}
505505
}
506506
};

mlir/lib/Bindings/Python/Globals.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ class PyGlobals {
6666
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
6767
bool replace = false);
6868

69+
/// Adds a user-friendly value caster. Raises an exception if the mapping
70+
/// already exists and replace == false. This is intended to be called by
71+
/// implementation code.
72+
void registerValueCaster(MlirTypeID mlirTypeID,
73+
pybind11::function valueCaster,
74+
bool replace = false);
75+
6976
/// Adds a concrete implementation dialect class.
7077
/// Raises an exception if the mapping already exists.
7178
/// This is intended to be called by implementation code.
@@ -86,6 +93,10 @@ class PyGlobals {
8693
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
8794
MlirDialect dialect);
8895

96+
/// Returns the custom value caster for MlirTypeID mlirTypeID.
97+
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
98+
MlirDialect dialect);
99+
89100
/// Looks up a registered dialect class by namespace. Note that this may
90101
/// trigger loading of the defining module and can arbitrarily re-enter.
91102
std::optional<pybind11::object>
@@ -109,7 +120,8 @@ class PyGlobals {
109120
llvm::StringMap<pybind11::object> attributeBuilderMap;
110121
/// Map of MlirTypeID to custom type caster.
111122
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
112-
123+
/// Map of MlirTypeID to custom value caster.
124+
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
113125
/// Set of dialect namespaces that we have attempted to import implementation
114126
/// modules for.
115127
llvm::StringSet<> loadedDialectModules;

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,13 +1899,28 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
18991899
}
19001900

19011901
//------------------------------------------------------------------------------
1902-
// PyValue and subclases.
1902+
// PyValue and subclasses.
19031903
//------------------------------------------------------------------------------
19041904

19051905
pybind11::object PyValue::getCapsule() {
19061906
return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
19071907
}
19081908

1909+
pybind11::object PyValue::maybeDownCast() {
1910+
MlirType type = mlirValueGetType(get());
1911+
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
1912+
assert(!mlirTypeIDIsNull(mlirTypeID) &&
1913+
"mlirTypeID was expected to be non-null.");
1914+
std::optional<pybind11::function> valueCaster =
1915+
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
1916+
// py::return_value_policy::move means use std::move to move the return value
1917+
// contents into a new instance that will be owned by Python.
1918+
py::object thisObj = py::cast(this, py::return_value_policy::move);
1919+
if (!valueCaster)
1920+
return thisObj;
1921+
return valueCaster.value()(thisObj);
1922+
}
1923+
19091924
PyValue PyValue::createFromCapsule(pybind11::object capsule) {
19101925
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
19111926
if (mlirValueIsNull(value))
@@ -2121,6 +2136,8 @@ class PyConcreteValue : public PyValue {
21212136
return DerivedTy::isaFunction(otherValue);
21222137
},
21232138
py::arg("other_value"));
2139+
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
2140+
[](DerivedTy &self) { return self.maybeDownCast(); });
21242141
DerivedTy::bindDerived(cls);
21252142
}
21262143

@@ -2193,6 +2210,7 @@ class PyBlockArgumentList
21932210
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
21942211
public:
21952212
static constexpr const char *pyClassName = "BlockArgumentList";
2213+
using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
21962214

21972215
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
21982216
intptr_t startIndex = 0, intptr_t length = -1,
@@ -2241,6 +2259,7 @@ class PyBlockArgumentList
22412259
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
22422260
public:
22432261
static constexpr const char *pyClassName = "OpOperandList";
2262+
using SliceableT = Sliceable<PyOpOperandList, PyValue>;
22442263

22452264
PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
22462265
intptr_t length = -1, intptr_t step = 1)
@@ -2296,14 +2315,15 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
22962315
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
22972316
public:
22982317
static constexpr const char *pyClassName = "OpResultList";
2318+
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
22992319

23002320
PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
23012321
intptr_t length = -1, intptr_t step = 1)
23022322
: Sliceable(startIndex,
23032323
length == -1 ? mlirOperationGetNumResults(operation->get())
23042324
: length,
23052325
step),
2306-
operation(operation) {}
2326+
operation(std::move(operation)) {}
23072327

23082328
static void bindDerived(ClassTy &c) {
23092329
c.def_property_readonly("types", [](PyOpResultList &self) {
@@ -2892,7 +2912,8 @@ void mlir::python::populateIRCore(py::module &m) {
28922912
.str());
28932913
}
28942914
return PyOpResult(operation.getRef(),
2895-
mlirOperationGetResult(operation, 0));
2915+
mlirOperationGetResult(operation, 0))
2916+
.maybeDownCast();
28962917
},
28972918
"Shortcut to get an op result if it has only one (throws an error "
28982919
"otherwise).")
@@ -3566,7 +3587,9 @@ void mlir::python::populateIRCore(py::module &m) {
35663587
[](PyValue &self, PyValue &with) {
35673588
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
35683589
},
3569-
kValueReplaceAllUsesWithDocstring);
3590+
kValueReplaceAllUsesWithDocstring)
3591+
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3592+
[](PyValue &self) { return self.maybeDownCast(); });
35703593
PyBlockArgument::bind(m);
35713594
PyOpResult::bind(m);
35723595
PyOpOperand::bind(m);

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
8888
found = std::move(typeCaster);
8989
}
9090

91+
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
92+
pybind11::function valueCaster,
93+
bool replace) {
94+
pybind11::object &found = valueCasterMap[mlirTypeID];
95+
if (found && !replace)
96+
throw std::runtime_error("Value caster is already registered: " +
97+
py::repr(found).cast<std::string>());
98+
found = std::move(valueCaster);
99+
}
100+
91101
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
92102
py::object pyClass) {
93103
py::object &found = dialectClassMap[dialectNamespace];
@@ -134,6 +144,17 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
134144
return std::nullopt;
135145
}
136146

147+
std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
148+
MlirDialect dialect) {
149+
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
150+
const auto foundIt = valueCasterMap.find(mlirTypeID);
151+
if (foundIt != valueCasterMap.end()) {
152+
assert(foundIt->second && "value caster is defined");
153+
return foundIt->second;
154+
}
155+
return std::nullopt;
156+
}
157+
137158
std::optional<py::object>
138159
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
139160
// Make sure dialect module is loaded.

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ class PyRegion {
761761

762762
/// Wrapper around an MlirAsmState.
763763
class PyAsmState {
764-
public:
764+
public:
765765
PyAsmState(MlirValue value, bool useLocalScope) {
766766
flags = mlirOpPrintingFlagsCreate();
767767
// The OpPrintingFlags are not exposed Python side, create locally and
@@ -780,16 +780,14 @@ class PyAsmState {
780780
state =
781781
mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
782782
}
783-
~PyAsmState() {
784-
mlirOpPrintingFlagsDestroy(flags);
785-
}
783+
~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
786784
// Delete copy constructors.
787785
PyAsmState(PyAsmState &other) = delete;
788786
PyAsmState(const PyAsmState &other) = delete;
789787

790788
MlirAsmState get() { return state; }
791789

792-
private:
790+
private:
793791
MlirAsmState state;
794792
MlirOpPrintingFlags flags;
795793
};
@@ -1112,6 +1110,10 @@ class PyConcreteAttribute : public BaseTy {
11121110
/// bindings so such operation always exists).
11131111
class PyValue {
11141112
public:
1113+
// The virtual here is "load bearing" in that it enables RTTI
1114+
// for PyConcreteValue CRTP classes that support maybeDownCast.
1115+
// See PyValue::maybeDownCast.
1116+
virtual ~PyValue() = default;
11151117
PyValue(PyOperationRef parentOperation, MlirValue value)
11161118
: parentOperation(std::move(parentOperation)), value(value) {}
11171119
operator MlirValue() const { return value; }
@@ -1124,6 +1126,8 @@ class PyValue {
11241126
/// Gets a capsule wrapping the void* within the MlirValue.
11251127
pybind11::object getCapsule();
11261128

1129+
pybind11::object maybeDownCast();
1130+
11271131
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
11281132
/// the underlying MlirValue is still tied to the owning operation.
11291133
static PyValue createFromCapsule(pybind11::object capsule);

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
#include "IRModule.h"
1313
#include "Pass.h"
1414

15-
#include <tuple>
16-
1715
namespace py = pybind11;
1816
using namespace mlir;
1917
using namespace py::literals;
@@ -46,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) {
4644
"dialect_namespace"_a, "dialect_class"_a,
4745
"Testing hook for directly registering a dialect")
4846
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
49-
"operation_name"_a, "operation_class"_a, "replace"_a = false,
47+
"operation_name"_a, "operation_class"_a, py::kw_only(),
48+
"replace"_a = false,
5049
"Testing hook for directly registering an operation");
5150

5251
// Aside from making the globals accessible to python, having python manage
@@ -82,17 +81,32 @@ PYBIND11_MODULE(_mlir, m) {
8281
return opClass;
8382
});
8483
},
85-
"dialect_class"_a, "replace"_a = false,
84+
"dialect_class"_a, py::kw_only(), "replace"_a = false,
8685
"Produce a class decorator for registering an Operation class as part of "
8786
"a dialect");
8887
m.def(
8988
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
90-
[](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
91-
PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
92-
replace);
89+
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
90+
return py::cpp_function([mlirTypeID,
91+
replace](py::object typeCaster) -> py::object {
92+
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
93+
return typeCaster;
94+
});
9395
},
94-
"typeid"_a, "type_caster"_a, "replace"_a = false,
96+
"typeid"_a, py::kw_only(), "replace"_a = false,
9597
"Register a type caster for casting MLIR types to custom user types.");
98+
m.def(
99+
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
100+
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
101+
return py::cpp_function(
102+
[mlirTypeID, replace](py::object valueCaster) -> py::object {
103+
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
104+
replace);
105+
return valueCaster;
106+
});
107+
},
108+
"typeid"_a, py::kw_only(), "replace"_a = false,
109+
"Register a value caster for casting MLIR values to custom user values.");
96110

97111
// Define and populate IR submodule.
98112
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");

mlir/lib/Bindings/Python/PybindUtils.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
1111

1212
#include "mlir-c/Support.h"
13+
#include "llvm/ADT/STLExtras.h"
1314
#include "llvm/ADT/Twine.h"
1415
#include "llvm/Support/DataTypes.h"
1516

@@ -228,6 +229,11 @@ class Sliceable {
228229
return linearIndex;
229230
}
230231

232+
/// Trait to check if T provides a `maybeDownCast` method.
233+
/// Note, you need the & to detect inherited members.
234+
template <typename T, typename... Args>
235+
using has_maybe_downcast = decltype(&T::maybeDownCast);
236+
231237
/// Returns the element at the given slice index. Supports negative indices
232238
/// by taking elements in inverse order. Returns a nullptr object if out
233239
/// of bounds.
@@ -239,8 +245,13 @@ class Sliceable {
239245
return {};
240246
}
241247

242-
return pybind11::cast(
243-
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
248+
if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
249+
return static_cast<Derived *>(this)
250+
->getRawElement(linearizeIndex(index))
251+
.maybeDownCast();
252+
else
253+
return pybind11::cast(
254+
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
244255
}
245256

246257
/// Returns a new instance of the pseudo-container restricted to the given

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
# Provide a convenient name for sub-packages to resolve the main C-extension
66
# with a relative import.
77
from .._mlir_libs import _mlir as _cext
8-
from typing import Sequence as _Sequence, Union as _Union
8+
from typing import (
9+
Sequence as _Sequence,
10+
Type as _Type,
11+
TypeVar as _TypeVar,
12+
Union as _Union,
13+
)
914

1015
__all__ = [
1116
"equally_sized_accessor",
@@ -123,3 +128,9 @@ def get_op_result_or_op_results(
123128
if len(op.results) > 0
124129
else op
125130
)
131+
132+
133+
# This is the standard way to indicate subclass/inheritance relationship
134+
# see the typing.Type doc string.
135+
_U = _TypeVar("_U", bound=_cext.ir.Value)
136+
SubClassValueT = _Type[_U]

mlir/python/mlir/ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._mlir_libs._mlir.ir import *
66
from ._mlir_libs._mlir.ir import _GlobalDebug
7-
from ._mlir_libs._mlir import register_type_caster
7+
from ._mlir_libs._mlir import register_type_caster, register_value_caster
88

99

1010
# Convenience decorator for registering user-friendly Attribute builders.

0 commit comments

Comments
 (0)