Skip to content

Commit 31ebe98

Browse files
authored
[mlir][c] Expose AsmState. (#66693)
Enable usage where capturing AsmState is good (e.g., avoiding creating AsmState over and over again when walking IR and printing). This also only changes one C API to verify plumbing. But using the AsmState makes the cost more explicit than the flags interface (which hides the traversals and construction here) and also enables a more efficient usage C side.
1 parent 6af39d9 commit 31ebe98

File tree

5 files changed

+83
-4
lines changed

5 files changed

+83
-4
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ extern "C" {
4848
}; \
4949
typedef struct name name
5050

51+
DEFINE_C_API_STRUCT(MlirAsmState, void);
5152
DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
5253
DEFINE_C_API_STRUCT(MlirContext, void);
5354
DEFINE_C_API_STRUCT(MlirDialect, void);
@@ -383,6 +384,29 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
383384
MLIR_CAPI_EXPORTED void
384385
mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
385386

387+
//===----------------------------------------------------------------------===//
388+
// AsmState API.
389+
// While many of these are simple settings that could be represented in a
390+
// struct, they are wrapped in a heap allocated object and accessed via
391+
// functions to maximize the possibility of compatibility over time.
392+
//===----------------------------------------------------------------------===//
393+
394+
/// Creates new AsmState, as with AsmState the IR should not be mutated
395+
/// in-between using this state.
396+
/// Must be freed with a call to mlirAsmStateDestroy().
397+
// TODO: This should be expanded to handle location & resouce map.
398+
MLIR_CAPI_EXPORTED MlirAsmState
399+
mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags);
400+
401+
/// Creates new AsmState from value.
402+
/// Must be freed with a call to mlirAsmStateDestroy().
403+
// TODO: This should be expanded to handle location & resouce map.
404+
MLIR_CAPI_EXPORTED MlirAsmState
405+
mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags);
406+
407+
/// Destroys printing flags created with mlirAsmStateCreate.
408+
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state);
409+
386410
//===----------------------------------------------------------------------===//
387411
// Op Printing flags API.
388412
// While many of these are simple settings that could be represented in a
@@ -815,7 +839,7 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
815839

816840
/// Prints a value as an operand (i.e., the ValueID).
817841
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
818-
MlirOpPrintingFlags flags,
842+
MlirAsmState state,
819843
MlirStringCallback callback,
820844
void *userData);
821845

mlir/include/mlir/CAPI/IR.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/MLIRContext.h"
2222
#include "mlir/IR/Operation.h"
2323

24+
DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
2425
DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
2526
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
2627
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3430,9 +3430,11 @@ void mlir::python::populateIRCore(py::module &m) {
34303430
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
34313431
if (useLocalScope)
34323432
mlirOpPrintingFlagsUseLocalScope(flags);
3433-
mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(),
3433+
MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
3434+
mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
34343435
printAccum.getUserData());
34353436
mlirOpPrintingFlagsDestroy(flags);
3437+
mlirAsmStateDestroy(state);
34363438
return printAccum.join();
34373439
},
34383440
py::arg("use_local_scope") = false, kGetNameAsOperand)

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,51 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
138138
delete unwrap(registry);
139139
}
140140

141+
//===----------------------------------------------------------------------===//
142+
// AsmState API.
143+
//===----------------------------------------------------------------------===//
144+
145+
MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
146+
MlirOpPrintingFlags flags) {
147+
return wrap(new AsmState(unwrap(op), *unwrap(flags)));
148+
}
149+
150+
static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
151+
do {
152+
// If we are printing local scope, stop at the first operation that is
153+
// isolated from above.
154+
if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
155+
break;
156+
157+
// Otherwise, traverse up to the next parent.
158+
Operation *parentOp = op->getParentOp();
159+
if (!parentOp)
160+
break;
161+
op = parentOp;
162+
} while (true);
163+
return op;
164+
}
165+
166+
MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
167+
MlirOpPrintingFlags flags) {
168+
Operation *op;
169+
mlir::Value val = unwrap(value);
170+
if (auto result = llvm::dyn_cast<OpResult>(val)) {
171+
op = result.getOwner();
172+
} else {
173+
op = llvm::cast<BlockArgument>(val).getOwner()->getParentOp();
174+
if (!op) {
175+
emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
176+
return {nullptr};
177+
}
178+
}
179+
op = findParent(op, unwrap(flags)->shouldUseLocalScope());
180+
return wrap(new AsmState(op, *unwrap(flags)));
181+
}
182+
183+
/// Destroys printing flags created with mlirAsmStateCreate.
184+
void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); }
185+
141186
//===----------------------------------------------------------------------===//
142187
// Printing flags API.
143188
//===----------------------------------------------------------------------===//
@@ -840,11 +885,11 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
840885
unwrap(value).print(stream);
841886
}
842887

843-
void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags,
888+
void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
844889
MlirStringCallback callback, void *userData) {
845890
detail::CallbackOstream stream(callback, userData);
846891
Value cppValue = unwrap(value);
847-
cppValue.printAsOperand(stream, *unwrap(flags));
892+
cppValue.printAsOperand(stream, *unwrap(state));
848893
}
849894

850895
MlirOpOperand mlirValueGetFirstUse(MlirValue value) {

mlir/test/CAPI/ir.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
487487
// CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
488488
// clang-format on
489489

490+
MlirAsmState state = mlirAsmStateCreateForOperation(parentOperation, flags);
491+
fprintf(stderr, "With state: |");
492+
mlirValuePrintAsOperand(value, state, printToStderr, NULL);
493+
// CHECK: With state: |%0|
494+
fprintf(stderr, "|\n");
495+
mlirAsmStateDestroy(state);
496+
490497
mlirOpPrintingFlagsDestroy(flags);
491498
}
492499

0 commit comments

Comments
 (0)