diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 7a0c95ebb8200..f01798f48ff86 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -23,6 +23,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python passmanager.py rewrite.py dialects/_ods_common.py + util.py # The main _mlir module has submodules: include stubs from each. _mlir_libs/_mlir/__init__.pyi diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 10abd06ff266e..da291d824aa8c 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -245,7 +245,7 @@ def _dispatch_mixed_values( dynamic_values.append(size) static_values = DenseI64ArrayAttr.get(static_values) - return (dynamic_values, packed_values, static_values) + return dynamic_values, packed_values, static_values def _get_value_or_attribute_value( diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 92da5df9bce66..418ec0bdbb6db 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -5,9 +5,11 @@ from ._arith_ops_gen import * from ._arith_ops_gen import _Dialect from ._arith_enum_gen import * +from ..util import is_integer_type, is_index_type, is_float_type from array import array as _array from typing import overload + try: from ..ir import * from ._ods_common import ( @@ -21,26 +23,6 @@ raise RuntimeError("Error loading imports from extension module") from e -def _isa(obj: Any, cls: type): - try: - cls(obj) - except ValueError: - return False - return True - - -def _is_any_of(obj: Any, classes: List[type]): - return any(_isa(obj, cls) for cls in classes) - - -def _is_integer_like_type(type: Type): - return _is_any_of(type, [IntegerType, IndexType]) - - -def _is_float_type(type: Type): - return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) - - @_ods_cext.register_operation(_Dialect, replace=True) class ConstantOp(ConstantOp): """Specialization for the constant op class.""" @@ -96,9 +78,9 @@ def value(self): @property def literal_value(self) -> Union[int, float]: - if _is_integer_like_type(self.type): + if is_integer_type(self.type) or is_index_type(self.type): return IntegerAttr(self.value).value - elif _is_float_type(self.type): + elif is_float_type(self.type): return FloatAttr(self.value).value else: raise ValueError("only integer and float constants have literal values") @@ -108,3 +90,19 @@ def constant( result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None ) -> Value: return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) + + +def index_cast( + in_: Value, + to: Type = None, + *, + out: Type = None, + loc: Location = None, + ip: InsertionPoint = None, +) -> Value: + if bool(to) != bool(out): + raise ValueError("either `to` or `out` must be set but not both") + res_type = out or to + if res_type is None: + res_type = IndexType.get() + return _get_op_result_or_op_results(IndexCastOp(res_type, in_, loc=loc, ip=ip)) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 254458a978828..cae70fc03b9d6 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -5,6 +5,13 @@ from typing import Callable, Dict, List, Sequence, Tuple, Union from .....ir import * +from .....util import ( + is_complex_type, + is_float_type, + is_index_type, + is_integer_type, + get_floating_point_width, +) from .... import func from .... import linalg @@ -412,9 +419,9 @@ def _cast( ) if operand.type == to_type: return operand - if _is_integer_type(to_type): + if is_integer_type(to_type): return self._cast_to_integer(to_type, operand, is_unsigned_cast) - elif _is_floating_point_type(to_type): + elif is_float_type(to_type): return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) def _cast_to_integer( @@ -422,11 +429,11 @@ def _cast_to_integer( ) -> Value: to_width = IntegerType(to_type).width operand_type = operand.type - if _is_floating_point_type(operand_type): + if is_float_type(operand_type): if is_unsigned_cast: return arith.FPToUIOp(to_type, operand).result return arith.FPToSIOp(to_type, operand).result - if _is_index_type(operand_type): + if is_index_type(operand_type): return arith.IndexCastOp(to_type, operand).result # Assume integer. from_width = IntegerType(operand_type).width @@ -444,13 +451,13 @@ def _cast_to_floating_point( self, to_type: Type, operand: Value, is_unsigned_cast: bool ) -> Value: operand_type = operand.type - if _is_integer_type(operand_type): + if is_integer_type(operand_type): if is_unsigned_cast: return arith.UIToFPOp(to_type, operand).result return arith.SIToFPOp(to_type, operand).result # Assume FloatType. - to_width = _get_floating_point_width(to_type) - from_width = _get_floating_point_width(operand_type) + to_width = get_floating_point_width(to_type) + from_width = get_floating_point_width(operand_type) if to_width > from_width: return arith.ExtFOp(to_type, operand).result elif to_width < from_width: @@ -466,89 +473,89 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, True) def _unary_exp(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if is_float_type(x.type): return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") def _unary_log(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if is_float_type(x.type): return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") def _unary_abs(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if is_float_type(x.type): return math.AbsFOp(x).result raise NotImplementedError("Unsupported 'abs' operand: {x}") def _unary_ceil(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if is_float_type(x.type): return math.CeilOp(x).result raise NotImplementedError("Unsupported 'ceil' operand: {x}") def _unary_floor(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if is_float_type(x.type): return math.FloorOp(x).result raise NotImplementedError("Unsupported 'floor' operand: {x}") def _unary_negf(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if is_float_type(x.type): return arith.NegFOp(x).result - if _is_complex_type(x.type): + if is_complex_type(x.type): return complex.NegOp(x).result raise NotImplementedError("Unsupported 'negf' operand: {x}") def _binary_add(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if is_float_type(lhs.type): return arith.AddFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if is_integer_type(lhs.type) or is_index_type(lhs.type): return arith.AddIOp(lhs, rhs).result - if _is_complex_type(lhs.type): + if is_complex_type(lhs.type): return complex.AddOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") def _binary_sub(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if is_float_type(lhs.type): return arith.SubFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if is_integer_type(lhs.type) or is_index_type(lhs.type): return arith.SubIOp(lhs, rhs).result - if _is_complex_type(lhs.type): + if is_complex_type(lhs.type): return complex.SubOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") def _binary_mul(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if is_float_type(lhs.type): return arith.MulFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if is_integer_type(lhs.type) or is_index_type(lhs.type): return arith.MulIOp(lhs, rhs).result - if _is_complex_type(lhs.type): + if is_complex_type(lhs.type): return complex.MulOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if is_float_type(lhs.type): return arith.MaximumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if is_integer_type(lhs.type) or is_index_type(lhs.type): return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if is_float_type(lhs.type): return arith.MaximumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if is_integer_type(lhs.type) or is_index_type(lhs.type): return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}") def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if is_float_type(lhs.type): return arith.MinimumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if is_integer_type(lhs.type) or is_index_type(lhs.type): return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if is_float_type(lhs.type): return arith.MinimumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if is_integer_type(lhs.type) or is_index_type(lhs.type): return arith.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}") @@ -609,40 +616,3 @@ def _add_type_mapping( ) type_mapping[name] = element_or_self_type block_arg_types.append(element_or_self_type) - - -def _is_complex_type(t: Type) -> bool: - return ComplexType.isinstance(t) - - -def _is_floating_point_type(t: Type) -> bool: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - return ( - F64Type.isinstance(t) - or F32Type.isinstance(t) - or F16Type.isinstance(t) - or BF16Type.isinstance(t) - ) - - -def _is_integer_type(t: Type) -> bool: - return IntegerType.isinstance(t) - - -def _is_index_type(t: Type) -> bool: - return IndexType.isinstance(t) - - -def _get_floating_point_width(t: Type) -> int: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - if F64Type.isinstance(t): - return 64 - if F32Type.isinstance(t): - return 32 - if F16Type.isinstance(t): - return 16 - if BF16Type.isinstance(t): - return 16 - raise NotImplementedError(f"Unhandled floating point type switch {t}") diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index bc9a3a52728ad..2130a1966d88e 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -7,8 +7,9 @@ from ._memref_ops_gen import * from ._ods_common import _dispatch_mixed_values, MixedValues -from .arith import ConstantOp, _is_integer_like_type +from .arith import ConstantOp from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation +from ..util import is_integer_like_type def _is_constant_int_like(i): @@ -16,7 +17,7 @@ def _is_constant_int_like(i): isinstance(i, Value) and isinstance(i.owner, Operation) and isinstance(i.owner.opview, ConstantOp) - and _is_integer_like_type(i.type) + and is_integer_like_type(i.type) ) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 678ceeebac204..ff020fffcbed9 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -5,10 +5,12 @@ from ._scf_ops_gen import * from ._scf_ops_gen import _Dialect -from .arith import constant +from . import arith +from ..extras.meta import region_op, region_adder try: from ..ir import * + from ..util import is_index_type from ._ods_common import ( get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, @@ -237,7 +239,7 @@ def for_( params = [start, stop, step] for i, p in enumerate(params): if isinstance(p, int): - p = constant(IndexType.get(), p) + p = arith.constant(IndexType.get(), p) elif isinstance(p, float): raise ValueError(f"{p=} must be int.") params[i] = p @@ -254,3 +256,136 @@ def for_( yield iv, iter_args[0], for_op.results[0] else: yield iv + + +def _parfor(op_ctor): + def _base( + lower_bounds, upper_bounds=None, steps=None, *, loc=None, ip=None, **kwargs + ): + if upper_bounds is None: + upper_bounds = lower_bounds + lower_bounds = [0] * len(upper_bounds) + if steps is None: + steps = [1] * len(lower_bounds) + + params = [lower_bounds, upper_bounds, steps] + for i, p in enumerate(params): + for j, pp in enumerate(p): + if isinstance(pp, int): + pp = arith.constant(IndexType.get(), pp) + assert isinstance(pp, Value), f"expected ir.Value, got {type(pp)=}" + if not is_index_type(pp.type): + pp = arith.index_cast(pp) + p[j] = pp + params[i] = p + + return op_ctor(*params, loc=loc, ip=ip, **kwargs) + + return _base + + +def _parfor_cm(op_ctor): + def _base(*args, **kwargs): + for_op = _parfor(op_ctor)(*args, **kwargs) + block = for_op.regions[0].blocks[0] + block_args = tuple(block.arguments) + with InsertionPoint(block): + yield block_args + + return _base + + +forall = _parfor_cm(ForallOp) + + +class ParallelOp(ParallelOp): + def __init__( + self, + lower_bounds: Sequence[Union[Operation, OpView, Value, int]], + upper_bounds: Sequence[Union[Operation, OpView, Value, int]], + steps: Sequence[Union[Value, int]], + inits: Optional[Sequence[Union[Operation, OpView, Sequence[Value]]]] = None, + *, + loc=None, + ip=None, + ): + assert len(lower_bounds) == len(upper_bounds) == len(steps) + if inits is None: + inits = [] + results = [i.type for i in inits] + iv_types = [IndexType.get()] * len(lower_bounds) + super().__init__( + results, + lower_bounds, + upper_bounds, + steps, + inits, + loc=loc, + ip=ip, + ) + self.regions[0].blocks.append(*iv_types) + + @property + def body(self): + return self.regions[0].blocks[0] + + @property + def induction_variables(self): + return self.body.arguments + + +parallel = _parfor_cm(ParallelOp) + + +class ReduceOp(ReduceOp): + def __init__( + self, + operands: Sequence[Union[Operation, OpView, Sequence[Value]]], + num_reductions: int, + *, + loc=None, + ip=None, + ): + operands = _get_op_results_or_values(operands) + super().__init__(operands, num_reductions, loc=loc, ip=ip) + for i in range(num_reductions): + self.regions[i].blocks.append(operands[i].type, operands[i].type) + + +def reduce_(*operands, num_reductions=1, loc=None, ip=None): + return ReduceOp(operands, num_reductions, loc=loc, ip=ip) + + +reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs)) + + +@region_adder(terminator=lambda xs: reduce_return(*xs)) +def another_reduce(reduce_op): + for r in reduce_op.regions: + if len(r.blocks[0].operations) == 0: + return r + + +@region_op +def in_parallel(): + return InParallelOp() + + +def parallel_insert_slice( + source: Union[Operation, OpView, Value], + dest: Union[Operation, OpView, Value], + offsets: Optional[Sequence[Union[Operation, OpView, Value, int]]] = None, + sizes: Optional[Sequence[Union[Operation, OpView, Value, int]]] = None, + strides: Optional[Sequence[Union[Operation, OpView, Value, int]]] = None, +): + from . import tensor + + @in_parallel + def foo(): + tensor.parallel_insert_slice( + source, + dest, + offsets, + sizes, + strides, + ) diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 146b5f85d07f5..40b063ca28e96 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -14,7 +14,11 @@ from typing import Sequence, Union from ._ods_common import _cext as _ods_cext -from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results +from ._ods_common import ( + get_op_result_or_op_results as _get_op_result_or_op_results, + _dispatch_mixed_values, + MixedValues, +) @_ods_cext.register_operation(_Dialect, replace=True) @@ -65,3 +69,33 @@ def empty( lambda result, dynamic_extents: GenerateOp(result, dynamic_extents), terminator=lambda args: YieldOp(args[0]), ) + + +def parallel_insert_slice( + source: Union[Operation, OpView, Value], + dest: Union[Operation, OpView, Value], + offsets: MixedValues, + sizes: MixedValues, + strides: MixedValues, +): + if offsets is None: + offsets = [] + if sizes is None: + sizes = [] + if strides is None: + strides = [] + + offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets) + sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes) + strides, _packed_strides, static_strides = _dispatch_mixed_values(strides) + + return ParallelInsertSliceOp( + source, + dest, + offsets, + sizes, + strides, + static_offsets, + static_sizes, + static_strides, + ) diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py index 3f2defadf7941..fabe1d8e141ed 100644 --- a/mlir/python/mlir/extras/meta.py +++ b/mlir/python/mlir/extras/meta.py @@ -6,7 +6,7 @@ from functools import wraps from ..dialects._ods_common import get_op_result_or_op_results -from ..ir import Type, InsertionPoint +from ..ir import Type, InsertionPoint, Value def op_region_builder(op, op_region, terminator=None): @@ -81,3 +81,17 @@ def maybe_no_args(*args, **kwargs): return op_decorator(*args, **kwargs) return maybe_no_args + + +def region_adder(terminator=None): + def wrapper(op_region_adder): + def region_adder_decorator(op, *args, **kwargs): + if isinstance(op, Value): + op = op.owner.opview + region = op_region_adder(op, *args, **kwargs) + + return op_region_builder(op, region, terminator) + + return region_adder_decorator + + return wrapper diff --git a/mlir/python/mlir/util.py b/mlir/python/mlir/util.py new file mode 100644 index 0000000000000..453b74777014f --- /dev/null +++ b/mlir/python/mlir/util.py @@ -0,0 +1,51 @@ +from .ir import ( + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + IndexType, + IntegerType, + Type, +) + + +def is_complex_type(t: Type) -> bool: + return ComplexType.isinstance(t) + + +def is_float_type(t: Type) -> bool: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + return ( + F64Type.isinstance(t) + or F32Type.isinstance(t) + or F16Type.isinstance(t) + or BF16Type.isinstance(t) + ) + + +def is_integer_type(t: Type) -> bool: + return IntegerType.isinstance(t) + + +def is_index_type(t: Type) -> bool: + return IndexType.isinstance(t) + + +def is_integer_like_type(t: Type) -> bool: + return is_integer_type(t) or is_index_type(t) + + +def get_floating_point_width(t: Type) -> int: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + if F64Type.isinstance(t): + return 64 + if F32Type.isinstance(t): + return 32 + if F16Type.isinstance(t): + return 16 + if BF16Type.isinstance(t): + return 16 + raise NotImplementedError(f"Unhandled floating point type switch {t}") diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index c9af5e7b46db8..0a197c4e673f9 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -4,6 +4,7 @@ from mlir.ir import * import mlir.dialects.arith as arith import mlir.dialects.func as func +from mlir.util import is_float_type, is_integer_like_type from array import array @@ -42,11 +43,9 @@ def testFastMathFlags(): def testArithValue(): def _binary_op(lhs, rhs, op: str) -> "ArithValue": op = op.capitalize() - if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type): + if is_float_type(lhs.type) and is_float_type(rhs.type): op += "F" - elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type( - lhs.type - ): + elif is_integer_like_type(lhs.type) and is_integer_like_type(lhs.type): op += "I" else: raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}") diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 62d11d5e189c8..d57b26f46ef0b 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -5,7 +5,7 @@ from mlir.dialects import func from mlir.dialects import memref from mlir.dialects import scf -from mlir.passmanager import PassManager +from mlir.dialects import tensor def constructAndPrintInModule(f): @@ -14,6 +14,7 @@ def constructAndPrintInModule(f): module = Module.create() with InsertionPoint(module.body): f() + assert module.operation.verify() print(module) return f @@ -38,6 +39,179 @@ def forall_loop(tensor): assert loop.verify() +# CHECK-LABEL: TEST: test_forall_insert_slice_no_region_with_for +@constructAndPrintInModule +def test_forall_insert_slice_no_region_with_for(): + i32 = IntegerType.get_signless(32) + f32 = F32Type.get() + ten = tensor.empty([10, 10], i32) + + for i, j, shared_outs in scf.forall([1, 1], [2, 2], [3, 3], shared_outs=[ten]): + one = arith.constant(f32, 1.0) + + scf.parallel_insert_slice( + ten, + shared_outs, + offsets=[i, j], + sizes=[10, 10], + strides=[1, 1], + ) + + # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32> + # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_3:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_5:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_6:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) { + # CHECK: %[[VAL_11:.*]] = arith.constant 1.000000e+00 : f32 + # CHECK: scf.forall.in_parallel { + # CHECK: tensor.parallel_insert_slice %[[VAL_0]] into %[[VAL_10]]{{\[}}%[[VAL_8]], %[[VAL_9]]] [10, 10] [1, 1] : tensor<10x10xi32> into tensor<10x10xi32> + # CHECK: } + # CHECK: } + + for ii, jj, shared_outs_1 in scf.forall([1, 1], [2, 2], [3, 3], shared_outs=[ten]): + ten_dynamic = tensor.empty([ii, 10], i32) + scf.parallel_insert_slice( + ten_dynamic, + shared_outs_1, + offsets=[ii, 0], + sizes=[ii, 10], + strides=[ii, 1], + ) + + # CHECK: %[[VAL_12:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_13:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_14:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_15:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_16:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_17:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_18:.*]] = scf.forall (%[[VAL_19:.*]], %[[VAL_20:.*]]) = (%[[VAL_12]], %[[VAL_13]]) to (%[[VAL_14]], %[[VAL_15]]) step (%[[VAL_16]], %[[VAL_17]]) shared_outs(%[[VAL_21:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) { + # CHECK: %[[VAL_22:.*]] = tensor.empty(%[[VAL_19]]) : tensor + # CHECK: scf.forall.in_parallel { + # CHECK: tensor.parallel_insert_slice %[[VAL_22]] into %[[VAL_21]]{{\[}}%[[VAL_19]], 0] {{\[}}%[[VAL_19]], 10] {{\[}}%[[VAL_19]], 1] : tensor into tensor<10x10xi32> + # CHECK: } + # CHECK: } + + +# CHECK-LABEL: TEST: test_parange_inits_with_for +@constructAndPrintInModule +def test_parange_inits_with_for(): + i32 = IntegerType.get_signless(32) + f32 = F32Type.get() + tensor_type = RankedTensorType.get([10, 10], f32) + ten = tensor.empty([10, 10], i32) + + for i, j in scf.parallel([1, 1], [2, 2], [3, 3], inits=[ten]): + one = arith.constant(f32, 1.0) + ten2 = tensor.empty([10, 10], i32) + + @scf.reduce(ten2) + def res(lhs: tensor_type, rhs: tensor_type): + return arith.addi(lhs, rhs) + + # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32> + # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_3:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_5:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_6:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_7:.*]] = scf.parallel (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) init (%[[VAL_0]]) -> tensor<10x10xi32> { + # CHECK: %[[VAL_10:.*]] = arith.constant 1.000000e+00 : f32 + # CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<10x10xi32> + # CHECK: scf.reduce(%[[VAL_11]] : tensor<10x10xi32>) { + # CHECK: ^bb0(%[[VAL_12:.*]]: tensor<10x10xi32>, %[[VAL_13:.*]]: tensor<10x10xi32>): + # CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : tensor<10x10xi32> + # CHECK: scf.reduce.return %[[VAL_14]] : tensor<10x10xi32> + # CHECK: } + # CHECK: } + + +# CHECK-LABEL: TEST: test_parange_inits_with_for_with_two_reduce +@constructAndPrintInModule +def test_parange_inits_with_for_with_two_reduce(): + index_type = IndexType.get() + one = arith.constant(index_type, 1) + + for i, j in scf.parallel([1, 1], [2, 2], [3, 3], inits=[one, one]): + + @scf.reduce(i, j, num_reductions=2) + def res1(lhs: index_type, rhs: index_type): + return arith.addi(lhs, rhs) + + @scf.another_reduce(res1) + def res2(lhs: index_type, rhs: index_type): + return arith.addi(lhs, rhs) + + # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_3:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_5:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_6:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_7:.*]]:2 = scf.parallel (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) init (%[[VAL_0]], %[[VAL_0]]) -> (index, index) { + # CHECK: scf.reduce(%[[VAL_8]], %[[VAL_9]] : index, index) { + # CHECK: ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index): + # CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index + # CHECK: scf.reduce.return %[[VAL_12]] : index + # CHECK: }, { + # CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: index): + # CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index + # CHECK: scf.reduce.return %[[VAL_15]] : index + # CHECK: } + # CHECK: } + + +# CHECK-LABEL: TEST: test_parange_inits_with_for_with_three_reduce +@constructAndPrintInModule +def test_parange_inits_with_for_with_three_reduce(): + index_type = IndexType.get() + one = arith.constant(index_type, 1) + + for i, j, k in scf.parallel([1, 1, 1], [2, 2, 2], [3, 3, 3], inits=[one, one, one]): + + @scf.reduce(i, j, k, num_reductions=3) + def res1(lhs: index_type, rhs: index_type): + return arith.addi(lhs, rhs) + + @scf.another_reduce(res1) + def res2(lhs: index_type, rhs: index_type): + return arith.addi(lhs, rhs) + + @scf.another_reduce(res2) + def res3(lhs: index_type, rhs: index_type): + return arith.addi(lhs, rhs) + + # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_3:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_5:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_6:.*]] = arith.constant 2 : index + # CHECK: %[[VAL_7:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_8:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_9:.*]] = arith.constant 3 : index + # CHECK: %[[VAL_10:.*]]:3 = scf.parallel (%[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]]) = (%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) to (%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) step (%[[VAL_7]], %[[VAL_8]], %[[VAL_9]]) init (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) -> (index, index, index) { + # CHECK: scf.reduce(%[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : index, index, index) { + # CHECK: ^bb0(%[[VAL_14:.*]]: index, %[[VAL_15:.*]]: index): + # CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index + # CHECK: scf.reduce.return %[[VAL_16]] : index + # CHECK: }, { + # CHECK: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index): + # CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index + # CHECK: scf.reduce.return %[[VAL_19]] : index + # CHECK: }, { + # CHECK: ^bb0(%[[VAL_20:.*]]: index, %[[VAL_21:.*]]: index): + # CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : index + # CHECK: scf.reduce.return %[[VAL_22]] : index + # CHECK: } + # CHECK: } + + # CHECK-LABEL: TEST: testSimpleLoop @constructAndPrintInModule def testSimpleLoop(): diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py index c2d51083c1379..4a01ed2ec6a61 100644 --- a/mlir/test/python/ir/auto_location.py +++ b/mlir/test/python/ir/auto_location.py @@ -39,7 +39,7 @@ def testInferLocations(): print(op.location) # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":32:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":47:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":92:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":32:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) # fmt: on print(one.location)