Skip to content

Commit 7545371

Browse files
authored
[mlir][python] Expose AsmState python side. (#66819)
This does basic plumbing, ideally want a context approach to reduce needing to thread these manually, but the current is useful even in that state. Made Value.get_name change backwards compatible, so one could either set a field or create a state to pass in.
1 parent f5b42ea commit 7545371

File tree

3 files changed

+65
-10
lines changed

3 files changed

+65
-10
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3430,19 +3430,35 @@ void mlir::python::populateIRCore(py::module &m) {
34303430
kValueDunderStrDocstring)
34313431
.def(
34323432
"get_name",
3433-
[](PyValue &self, bool useLocalScope) {
3433+
[](PyValue &self, std::optional<bool> useLocalScope,
3434+
std::optional<std::reference_wrapper<PyAsmState>> state) {
34343435
PyPrintAccumulator printAccum;
3435-
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3436-
if (useLocalScope)
3437-
mlirOpPrintingFlagsUseLocalScope(flags);
3438-
MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
3439-
mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
3436+
MlirOpPrintingFlags flags;
3437+
MlirAsmState valueState;
3438+
// Use state if provided, else create a new state.
3439+
if (state) {
3440+
valueState = state.value().get().get();
3441+
// Don't allow setting using local scope and state at same time.
3442+
if (useLocalScope.has_value())
3443+
throw py::value_error(
3444+
"setting AsmState and local scope together not supported");
3445+
} else {
3446+
flags = mlirOpPrintingFlagsCreate();
3447+
if (useLocalScope.value_or(false))
3448+
mlirOpPrintingFlagsUseLocalScope(flags);
3449+
valueState = mlirAsmStateCreateForValue(self.get(), flags);
3450+
}
3451+
mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(),
34403452
printAccum.getUserData());
3441-
mlirOpPrintingFlagsDestroy(flags);
3442-
mlirAsmStateDestroy(state);
3453+
// Release state if allocated locally.
3454+
if (!state) {
3455+
mlirOpPrintingFlagsDestroy(flags);
3456+
mlirAsmStateDestroy(valueState);
3457+
}
34433458
return printAccum.join();
34443459
},
3445-
py::arg("use_local_scope") = false, kGetNameAsOperand)
3460+
py::arg("use_local_scope") = std::nullopt,
3461+
py::arg("state") = std::nullopt, kGetNameAsOperand)
34463462
.def_property_readonly(
34473463
"type", [](PyValue &self) { return mlirValueGetType(self.get()); })
34483464
.def(
@@ -3461,6 +3477,10 @@ void mlir::python::populateIRCore(py::module &m) {
34613477
PyOpResult::bind(m);
34623478
PyOpOperand::bind(m);
34633479

3480+
py::class_<PyAsmState>(m, "AsmState", py::module_local())
3481+
.def(py::init<PyValue &, bool>(), py::arg("value"),
3482+
py::arg("use_local_scope") = false);
3483+
34643484
//----------------------------------------------------------------------------
34653485
// Mapping of SymbolTable.
34663486
//----------------------------------------------------------------------------

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,31 @@ class PyRegion {
748748
MlirRegion region;
749749
};
750750

751+
/// Wrapper around an MlirAsmState.
752+
class PyAsmState {
753+
public:
754+
PyAsmState(MlirValue value, bool useLocalScope) {
755+
flags = mlirOpPrintingFlagsCreate();
756+
// The OpPrintingFlags are not exposed Python side, create locally and
757+
// associate lifetime with the state.
758+
if (useLocalScope)
759+
mlirOpPrintingFlagsUseLocalScope(flags);
760+
state = mlirAsmStateCreateForValue(value, flags);
761+
}
762+
~PyAsmState() {
763+
mlirOpPrintingFlagsDestroy(flags);
764+
}
765+
// Delete copy constructors.
766+
PyAsmState(PyAsmState &other) = delete;
767+
PyAsmState(const PyAsmState &other) = delete;
768+
769+
MlirAsmState get() { return state; }
770+
771+
private:
772+
MlirAsmState state;
773+
MlirOpPrintingFlags flags;
774+
};
775+
751776
/// Wrapper around an MlirBlock.
752777
/// Blocks are managed completely by their containing operation. Unlike the
753778
/// C++ API, the python API does not support detached blocks.

mlir/test/python/ir/value.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# RUN: %PYTHON %s | FileCheck %s
1+
# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
22

33
import gc
44
from mlir.ir import *
@@ -199,6 +199,16 @@ def testValuePrintAsOperand():
199199
# CHECK: %[[VAL4]]
200200
print(value4.get_name())
201201

202+
print("With AsmState")
203+
# CHECK-LABEL: With AsmState
204+
state = AsmState(value3, use_local_scope=True)
205+
# CHECK: %0
206+
print(value3.get_name(state=state))
207+
# CHECK: %1
208+
print(value4.get_name(state=state))
209+
210+
print("With use_local_scope")
211+
# CHECK-LABEL: With use_local_scope
202212
# CHECK: %0
203213
print(value3.get_name(use_local_scope=True))
204214
# CHECK: %1

0 commit comments

Comments
 (0)