diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 7655629a55425..895c3228139b3 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -13,6 +13,7 @@ "get_default_loc_context", "get_op_result_or_value", "get_op_results_or_values", + "get_op_result_or_op_results", "segmented_accessor", ] @@ -167,3 +168,17 @@ def get_op_results_or_values( return arg.results else: return [get_op_result_or_value(element) for element in arg] + + +def get_op_result_or_op_results( + op: _Union[_cext.ir.OpView, _cext.ir.Operation], +) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]: + if isinstance(op, _cext.ir.OpView): + op = op.operation + return ( + list(get_op_results_or_values(op)) + if len(op.results) > 1 + else get_op_result_or_value(op) + if len(op.results) > 0 + else op + ) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py index 4b0a31327abb0..89cc8a19895c7 100644 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -7,7 +7,8 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Any, Optional, Sequence, Union +from typing import Optional, Sequence, Union + from ._ods_common import ( get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, @@ -25,7 +26,7 @@ def __init__( iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, *, loc=None, - ip=None + ip=None, ): """Creates an SCF `for` operation. diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 302a49d56c211..49685ca2271fc 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -2,4 +2,42 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional, Sequence + from ._scf_ops_gen import * +from .arith import constant +from ..ir import * + + +def for_( + start, + stop=None, + step=None, + iter_args: Optional[Sequence[Value]] = None, + *, + loc=None, + ip=None, +): + if step is None: + step = 1 + if stop is None: + stop = start + start = 0 + params = [start, stop, step] + for i, p in enumerate(params): + if isinstance(p, int): + p = constant(p) + elif isinstance(p, float): + raise ValueError(f"{p=} must be int.") + params[i] = p + + for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip) + iv = for_op.induction_variable + iter_args = tuple(for_op.inner_iter_args) + with InsertionPoint(for_op.body): + if len(iter_args) > 1: + yield iv, iter_args + elif len(iter_args) == 1: + yield iv, iter_args[0] + else: + yield iv diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index a131209fa45cb..8ca23fa9f45c4 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -60,6 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", Optional:$variadic2); } +// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results" @@ -104,6 +107,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", Variadic:$variadic2); } +// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOp(_ods_ir.OpView): @@ -151,6 +157,9 @@ def AttributedOp : TestOp<"attributed_op"> { UnitAttr:$unitAttr, I32Attr:$in); } +// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOpWithOperands(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands" @@ -184,6 +193,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr:$is); } +// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs" @@ -205,6 +217,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> { let results = (outs); } +// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)) + // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { // CHECK: def __init__(self, type_, *, loc=None, ip=None): @@ -220,6 +235,9 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu let results = (outs AnyType:$res, AnyType); } +// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)) + // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None): @@ -227,6 +245,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir let results = (outs AnyType:$res, Variadic); } +// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class EmptyOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.empty" @@ -241,6 +262,8 @@ def EmptyOp : TestOp<"empty">; // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) +// CHECK: def empty(*, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)) // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { @@ -252,6 +275,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { let results = (outs I32:$i32, F32:$f32); } +// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)) + // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { // CHECK: def __init__(self, *, loc=None, ip=None): @@ -262,6 +288,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> let results = (outs AnyType, AnyType, AnyType); } +// CHECK: def infer_result_types_op(*, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.missing_names" @@ -297,6 +326,9 @@ def MissingNamesOp : TestOp<"missing_names"> { let results = (outs I32:$i32, AnyFloat, I64:$i64); } +// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneOptionalOperandOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand" @@ -323,9 +355,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> { // CHECK: @builtins.property // CHECK: def optional(self): // CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1] - } +// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand" @@ -355,6 +389,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { let arguments = (ins AnyType:$non_variadic, Variadic:$variadic); } +// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicResultOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result" @@ -385,6 +422,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> { let results = (outs Variadic:$variadic, AnyType:$non_variadic); } +// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class PythonKeywordOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.python_keyword" @@ -405,6 +445,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> { // CHECK: return self.operation.operands[0] let arguments = (ins AnyType:$in); } + +// CHECK: def python_keyword(in_, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)) + // CHECK-LABEL: OPERATION_NAME = "test.same_results" def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None): @@ -416,6 +460,9 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { let results = (outs AnyType:$res); } +// CHECK: def same_results(in1, in2, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)) + // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic" def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> { // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None): @@ -423,6 +470,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu let results = (outs Variadic:$res); } +// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView): @@ -447,6 +497,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand", Variadic:$variadic2); } +// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result" @@ -470,6 +523,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result", Variadic:$variadic2); } +// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SimpleOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.simple" @@ -507,6 +563,9 @@ def SimpleOp : TestOp<"simple"> { let results = (outs I64:$i64, AnyFloat:$f64); } +// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)) + // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region" def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> { @@ -531,6 +590,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> { // CHECK: return self.regions[2:] } +// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) + // CHECK: class VariadicRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_region" def VariadicRegionOp : TestOp<"variadic_region"> { @@ -551,6 +613,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> { // CHECK: return self.regions[0:] } +// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class WithSuccessorsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.with_successors" @@ -562,3 +627,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> { let successors = (successor AnySuccessor:$successor, VariadicSuccessor:$successors); } + +// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) +// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)) \ No newline at end of file diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 8cb55fdf6a1eb..414307d819151 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -4,7 +4,6 @@ from mlir.dialects import arith from mlir.dialects import func from mlir.dialects import scf -from mlir.dialects import builtin def constructAndPrintInModule(f): @@ -54,6 +53,28 @@ def induction_var(lb, ub, step): # CHECK: scf.yield %[[IV]] +# CHECK-LABEL: TEST: testForSugar +@constructAndPrintInModule +def testForSugar(): + index_type = IndexType.get() + range = scf.for_ + + @func.FuncOp.from_py_func(index_type, index_type, index_type) + def range_loop(lb, ub, step): + for i in range(lb, ub, step): + add = arith.addi(i, i) + scf.yield_([]) + return + + +# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) { +# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] +# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index +# CHECK: } +# CHECK: return +# CHECK: } + + @constructAndPrintInModule def testOpsAsArguments(): index_type = IndexType.get() diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 0b5df7ab70ddd..fc094a1829ff7 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from ._ods_common import _cext as _ods_cext -from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results _ods_ir = _ods_cext.ir try: @@ -39,6 +39,7 @@ except ImportError: _ods_ext_module = None import builtins +from typing import Sequence as _Sequence, Union as _Union )Py"; @@ -260,11 +261,16 @@ constexpr const char *attributeDeleterTemplate = R"Py( del self.operation.attributes["{1}"] )Py"; -constexpr const char *regionAccessorTemplate = R"PY( +constexpr const char *regionAccessorTemplate = R"Py( @builtins.property def {0}(self): return self.regions[{1}] -)PY"; +)Py"; + +constexpr const char *valueBuilderTemplate = R"Py( +def {0}({2}) -> {4}: + return _get_op_result_or_op_results({1}({3})) +)Py"; static llvm::cl::OptionCategory clOpPythonBindingCat("Options for -gen-python-op-bindings"); @@ -609,9 +615,7 @@ populateBuilderArgsResults(const Operator &op, static void populateBuilderArgs(const Operator &op, llvm::SmallVectorImpl &builderArgs, - llvm::SmallVectorImpl &operandNames, - llvm::SmallVectorImpl &successorArgNames) { - + llvm::SmallVectorImpl &operandNames) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) { std::string name = op.getArgName(i).str(); if (name.empty()) @@ -734,11 +738,11 @@ populateBuilderLinesOperand(const Operator &op, /// attribute: /// - {0} is the name of the attribute from which to derive the types. constexpr const char *deriveTypeFromAttrTemplate = - R"PY(_ods_result_type_source_attr = attributes["{0}"] + R"Py(_ods_result_type_source_attr = attributes["{0}"] _ods_derived_result_type = ( _ods_ir.TypeAttr(_ods_result_type_source_attr).value if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else - _ods_result_type_source_attr.type))PY"; + _ods_result_type_source_attr.type))Py"; /// Python code template appending {0} type {1} times to the results list. constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; @@ -836,11 +840,14 @@ populateBuilderRegions(const Operator &op, } /// Emits a default builder constructing an operation from the list of its -/// result types, followed by a list of its operands. -static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { +/// result types, followed by a list of its operands. Returns vector +/// of fully built functionArgs for downstream users (to save having to +/// rebuild anew). +static llvm::SmallVector emitDefaultOpBuilder(const Operator &op, + raw_ostream &os) { // If we are asked to skip default builders, comply. if (op.skipDefaultBuilders()) - return; + return {}; llvm::SmallVector builderArgs; llvm::SmallVector builderLines; @@ -850,7 +857,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { op.getNumNativeAttributes() + op.getNumSuccessors()); populateBuilderArgsResults(op, builderArgs); size_t numResultArgs = builderArgs.size(); - populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); + populateBuilderArgs(op, builderArgs, operandArgNames); size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); @@ -921,6 +928,8 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), llvm::join(builderLines, "\n "), llvm::join(initArgs, ", ")); + return llvm::to_vector<8>( + llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); })); } static void emitSegmentSpec( @@ -968,6 +977,45 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) { } } +/// Emits builder that extracts results from op +static void emitValueBuilder(const Operator &op, + llvm::SmallVector functionArgs, + raw_ostream &os) { + // If we are asked to skip default builders, comply. + if (op.skipDefaultBuilders()) + return; + auto name = sanitizeName(op.getOperationName()); + iterator_range splitName = llvm::split(name, "."); + // Params with (possibly) default args. + auto valueBuilderParams = + llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) { + llvm::SmallVector argMaybeDefault = + llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "=")); + auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]); + if (argMaybeDefault.size() == 2) + return arg + "=" + argMaybeDefault[1].str(); + return arg; + }); + // Actual args passed to op builder (e.g., opParam=op_param). + auto opBuilderArgs = llvm::map_range( + llvm::make_filter_range(functionArgs, + [](const std::string &s) { return s != "*"; }), + [](const std::string &arg) { + auto lhs = *llvm::split(arg, "=").begin(); + return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str(); + }); + os << llvm::formatv( + valueBuilderTemplate, + // Drop dialect name and then sanitize again (to catch e.g. func.return). + sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")), + op.getCppClassName(), llvm::join(valueBuilderParams, ", "), + llvm::join(opBuilderArgs, ", "), + (op.getNumResults() > 1 + ? "_Sequence[_ods_ir.OpResult]" + : (op.getNumResults() > 0 ? "_ods_ir.OpResult" + : "_ods_ir.Operation"))); +} + /// Emits bindings for a specific Op to the given output stream. static void emitOpBindings(const Operator &op, raw_ostream &os) { os << llvm::formatv(opClassTemplate, op.getCppClassName(), @@ -982,11 +1030,12 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) { } emitRegionAttributes(op, os); - emitDefaultOpBuilder(op, os); + llvm::SmallVector functionArgs = emitDefaultOpBuilder(op, os); emitOperandAccessors(op, os); emitAttributeAccessors(op, os); emitResultAccessors(op, os); emitRegionAccessors(op, os); + emitValueBuilder(op, functionArgs, os); } /// Emits bindings for the dialect specified in the command line, including file