Skip to content

[mlir][python] add scf.parallel/scf.forall helpers #150243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
passmanager.py
rewrite.py
dialects/_ods_common.py
util.py

# The main _mlir module has submodules: include stubs from each.
_mlir_libs/_mlir/__init__.pyi
Expand Down
2 changes: 1 addition & 1 deletion mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _dispatch_mixed_values(
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)

return (dynamic_values, packed_values, static_values)
return dynamic_values, packed_values, static_values


def _get_value_or_attribute_value(
Expand Down
42 changes: 20 additions & 22 deletions mlir/python/mlir/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
from ..util import is_integer_type, is_index_type, is_float_type
from array import array as _array
from typing import overload


try:
from ..ir import *
from ._ods_common import (
Expand All @@ -21,26 +23,6 @@
raise RuntimeError("Error loading imports from extension module") from e


def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True


def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)


def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])


def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])


@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
Expand Down Expand Up @@ -96,9 +78,9 @@ def value(self):

@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
if is_integer_type(self.type) or is_index_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
elif is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
Expand All @@ -108,3 +90,19 @@ def constant(
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))


def index_cast(
in_: Value,
to: Type = None,
*,
out: Type = None,
loc: Location = None,
ip: InsertionPoint = None,
) -> Value:
if bool(to) != bool(out):
raise ValueError("either `to` or `out` must be set but not both")
res_type = out or to
if res_type is None:
res_type = IndexType.get()
return _get_op_result_or_op_results(IndexCastOp(res_type, in_, loc=loc, ip=ip))
106 changes: 38 additions & 68 deletions mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from typing import Callable, Dict, List, Sequence, Tuple, Union

from .....ir import *
from .....util import (
is_complex_type,
is_float_type,
is_index_type,
is_integer_type,
get_floating_point_width,
)

from .... import func
from .... import linalg
Expand Down Expand Up @@ -412,21 +419,21 @@ def _cast(
)
if operand.type == to_type:
return operand
if _is_integer_type(to_type):
if is_integer_type(to_type):
return self._cast_to_integer(to_type, operand, is_unsigned_cast)
elif _is_floating_point_type(to_type):
elif is_float_type(to_type):
return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)

def _cast_to_integer(
self, to_type: Type, operand: Value, is_unsigned_cast: bool
) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
if _is_floating_point_type(operand_type):
if is_float_type(operand_type):
if is_unsigned_cast:
return arith.FPToUIOp(to_type, operand).result
return arith.FPToSIOp(to_type, operand).result
if _is_index_type(operand_type):
if is_index_type(operand_type):
return arith.IndexCastOp(to_type, operand).result
# Assume integer.
from_width = IntegerType(operand_type).width
Expand All @@ -444,13 +451,13 @@ def _cast_to_floating_point(
self, to_type: Type, operand: Value, is_unsigned_cast: bool
) -> Value:
operand_type = operand.type
if _is_integer_type(operand_type):
if is_integer_type(operand_type):
if is_unsigned_cast:
return arith.UIToFPOp(to_type, operand).result
return arith.SIToFPOp(to_type, operand).result
# Assume FloatType.
to_width = _get_floating_point_width(to_type)
from_width = _get_floating_point_width(operand_type)
to_width = get_floating_point_width(to_type)
from_width = get_floating_point_width(operand_type)
if to_width > from_width:
return arith.ExtFOp(to_type, operand).result
elif to_width < from_width:
Expand All @@ -466,89 +473,89 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, True)

def _unary_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")

def _unary_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")

def _unary_abs(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.AbsFOp(x).result
raise NotImplementedError("Unsupported 'abs' operand: {x}")

def _unary_ceil(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.CeilOp(x).result
raise NotImplementedError("Unsupported 'ceil' operand: {x}")

def _unary_floor(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.FloorOp(x).result
raise NotImplementedError("Unsupported 'floor' operand: {x}")

def _unary_negf(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return arith.NegFOp(x).result
if _is_complex_type(x.type):
if is_complex_type(x.type):
return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")

def _binary_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.AddIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if is_complex_type(lhs.type):
return complex.AddOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")

def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if is_complex_type(lhs.type):
return complex.SubOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")

def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if is_complex_type(lhs.type):
return complex.MulOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")

def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MaximumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")

def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MaximumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")

def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MinimumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")

def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MinimumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MinUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")

Expand Down Expand Up @@ -609,40 +616,3 @@ def _add_type_mapping(
)
type_mapping[name] = element_or_self_type
block_arg_types.append(element_or_self_type)


def _is_complex_type(t: Type) -> bool:
return ComplexType.isinstance(t)


def _is_floating_point_type(t: Type) -> bool:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
return (
F64Type.isinstance(t)
or F32Type.isinstance(t)
or F16Type.isinstance(t)
or BF16Type.isinstance(t)
)


def _is_integer_type(t: Type) -> bool:
return IntegerType.isinstance(t)


def _is_index_type(t: Type) -> bool:
return IndexType.isinstance(t)


def _get_floating_point_width(t: Type) -> int:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
if F64Type.isinstance(t):
return 64
if F32Type.isinstance(t):
return 32
if F16Type.isinstance(t):
return 16
if BF16Type.isinstance(t):
return 16
raise NotImplementedError(f"Unhandled floating point type switch {t}")
5 changes: 3 additions & 2 deletions mlir/python/mlir/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@

from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from .arith import ConstantOp
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
from ..util import is_integer_like_type


def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
and _is_integer_like_type(i.type)
and is_integer_like_type(i.type)
)


Expand Down
Loading