Skip to content

[mlir][python] generate value builders #68308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
"get_op_result_or_op_results",
"segmented_accessor",
]

Expand Down Expand Up @@ -167,3 +168,17 @@ def get_op_results_or_values(
return arg.results
else:
return [get_op_result_or_value(element) for element in arg]


def get_op_result_or_op_results(
op: _Union[_cext.ir.OpView, _cext.ir.Operation],
) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
if isinstance(op, _cext.ir.OpView):
op = op.operation
return (
list(get_op_results_or_values(op))
if len(op.results) > 1
else get_op_result_or_value(op)
if len(op.results) > 0
else op
)
5 changes: 3 additions & 2 deletions mlir/python/mlir/dialects/_scf_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Any, Optional, Sequence, Union
from typing import Optional, Sequence, Union

from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
Expand All @@ -25,7 +26,7 @@ def __init__(
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None
ip=None,
):
"""Creates an SCF `for` operation.

Expand Down
38 changes: 38 additions & 0 deletions mlir/python/mlir/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,42 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional, Sequence

from ._scf_ops_gen import *
from .arith import constant
from ..ir import *


def for_(
start,
stop=None,
step=None,
iter_args: Optional[Sequence[Value]] = None,
*,
loc=None,
ip=None,
):
if step is None:
step = 1
if stop is None:
stop = start
start = 0
params = [start, stop, step]
for i, p in enumerate(params):
if isinstance(p, int):
p = constant(p)
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p

for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
iv = for_op.induction_variable
iter_args = tuple(for_op.inner_iter_args)
with InsertionPoint(for_op.body):
if len(iter_args) > 1:
yield iv, iter_args
elif len(iter_args) == 1:
yield iv, iter_args[0]
else:
yield iv
70 changes: 69 additions & 1 deletion mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
Optional<AnyType>:$variadic2);
}

// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
Expand Down Expand Up @@ -104,6 +107,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
Variadic<AnyType>:$variadic2);
}

// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))


// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOp(_ods_ir.OpView):
Expand Down Expand Up @@ -151,6 +157,9 @@ def AttributedOp : TestOp<"attributed_op"> {
UnitAttr:$unitAttr, I32Attr:$in);
}

// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
Expand Down Expand Up @@ -184,6 +193,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}

// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
Expand All @@ -205,6 +217,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
let results = (outs);
}

// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
Expand All @@ -220,13 +235,19 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
let results = (outs AnyType:$res, AnyType);
}

// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, Variadic<AnyType>);
}

// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.empty"
Expand All @@ -241,6 +262,8 @@ def EmptyOp : TestOp<"empty">;
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))

// CHECK: def empty(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
Expand All @@ -252,6 +275,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
let results = (outs I32:$i32, F32:$f32);
}

// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
// CHECK: def __init__(self, *, loc=None, ip=None):
Expand All @@ -262,6 +288,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
let results = (outs AnyType, AnyType, AnyType);
}

// CHECK: def infer_result_types_op(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
Expand Down Expand Up @@ -297,6 +326,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}

// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
Expand All @@ -323,9 +355,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: @builtins.property
// CHECK: def optional(self):
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]

}

// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
Expand Down Expand Up @@ -355,6 +389,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
}

// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
Expand Down Expand Up @@ -385,6 +422,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
}

// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
Expand All @@ -405,6 +445,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: return self.operation.operands[0]
let arguments = (ins AnyType:$in);
}

// CHECK: def python_keyword(in_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
// CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
Expand All @@ -416,13 +460,19 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
let results = (outs AnyType:$res);
}

// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
// CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
let arguments = (ins AnyType:$in1, AnyType:$in2);
let results = (outs Variadic<AnyType>:$res);
}

// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))


// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
Expand All @@ -447,6 +497,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
Variadic<AnyType>:$variadic2);
}

// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
Expand All @@ -470,6 +523,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
Variadic<AnyType>:$variadic2);
}

// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.simple"
Expand Down Expand Up @@ -507,6 +563,9 @@ def SimpleOp : TestOp<"simple"> {
let results = (outs I64:$i64, AnyFloat:$f64);
}

// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))

// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
Expand All @@ -531,6 +590,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: return self.regions[2:]
}

// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))

// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
def VariadicRegionOp : TestOp<"variadic_region"> {
Expand All @@ -551,6 +613,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: return self.regions[0:]
}

// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
Expand All @@ -562,3 +627,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
let successors = (successor AnySuccessor:$successor,
VariadicSuccessor<AnySuccessor>:$successors);
}

// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
23 changes: 22 additions & 1 deletion mlir/test/python/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import scf
from mlir.dialects import builtin


def constructAndPrintInModule(f):
Expand Down Expand Up @@ -54,6 +53,28 @@ def induction_var(lb, ub, step):
# CHECK: scf.yield %[[IV]]


# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
index_type = IndexType.get()
range = scf.for_

@func.FuncOp.from_py_func(index_type, index_type, index_type)
def range_loop(lb, ub, step):
for i in range(lb, ub, step):
add = arith.addi(i, i)
scf.yield_([])
return


# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
# CHECK: }
# CHECK: return
# CHECK: }


@constructAndPrintInModule
def testOpsAsArguments():
index_type = IndexType.get()
Expand Down
Loading