[mlir][python] add scf.parallel/scf.forall helpers #150243

Open
wants to merge 13 commits into
base: main

Conversation

makslevental
Contributor

@makslevental makslevental commented Jul 23, 2025

This PR upstreams scf.parallel/scf.forall helpers from mlir-python-extras. Based on #149416.

This PR also factors some common utility functions out into a top-level util.py. Note: the refactor is done in this PR because those utility functions are used by the upstreamed helpers.
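For orientation, the new helpers accept bounds the way Python's range does (upper bounds only, or lower/upper/step). A minimal plain-Python sketch of that defaulting convention, mirroring the `_parfor` logic in the diff below (function name hypothetical, no MLIR required):

```python
def normalize_bounds(lower_bounds, upper_bounds=None, steps=None):
    """range-like defaulting: a single list is treated as upper bounds,
    with zero lower bounds and unit steps filled in."""
    if upper_bounds is None:
        upper_bounds = lower_bounds
        lower_bounds = [0] * len(upper_bounds)
    if steps is None:
        steps = [1] * len(lower_bounds)
    return lower_bounds, upper_bounds, steps
```

For example, `normalize_bounds([4, 8])` yields `([0, 0], [4, 8], [1, 1])`.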

@makslevental makslevental changed the title Makslevental/scf forall parallel [mlir][python] add scf.parallel/scf.forall helpers Jul 23, 2025
@makslevental makslevental marked this pull request as ready for review August 8, 2025 00:13
@llvmbot llvmbot added mlir:linalg mlir:python MLIR Python bindings mlir labels Aug 8, 2025
@llvmbot
Member

llvmbot commented Aug 8, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

WIP

cc @Cubevoid


Patch is 29.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150243.diff

10 Files Affected:

  • (modified) mlir/python/CMakeLists.txt (+1)
  • (modified) mlir/python/mlir/dialects/arith.py (+20-22)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py (+38-68)
  • (modified) mlir/python/mlir/dialects/memref.py (+3-2)
  • (modified) mlir/python/mlir/dialects/scf.py (+135-2)
  • (modified) mlir/python/mlir/dialects/tensor.py (+42)
  • (modified) mlir/python/mlir/extras/meta.py (+15-1)
  • (added) mlir/python/mlir/util.py (+51)
  • (modified) mlir/test/python/dialects/arith_dialect.py (+3-4)
  • (modified) mlir/test/python/dialects/scf.py (+151-1)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..f01798f48ff86 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -23,6 +23,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
+    util.py
 
     # The main _mlir module has submodules: include stubs from each.
     _mlir_libs/_mlir/__init__.pyi
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce66..418ec0bdbb6db 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,9 +5,11 @@
 from ._arith_ops_gen import *
 from ._arith_ops_gen import _Dialect
 from ._arith_enum_gen import *
+from ..util import is_integer_type, is_index_type, is_float_type
 from array import array as _array
 from typing import overload
 
+
 try:
     from ..ir import *
     from ._ods_common import (
@@ -21,26 +23,6 @@
     raise RuntimeError("Error loading imports from extension module") from e
 
 
-def _isa(obj: Any, cls: type):
-    try:
-        cls(obj)
-    except ValueError:
-        return False
-    return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
-    return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
-    return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
-    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
@@ -96,9 +78,9 @@ def value(self):
 
     @property
     def literal_value(self) -> Union[int, float]:
-        if _is_integer_like_type(self.type):
+        if is_integer_type(self.type) or is_index_type(self.type):
             return IntegerAttr(self.value).value
-        elif _is_float_type(self.type):
+        elif is_float_type(self.type):
             return FloatAttr(self.value).value
         else:
             raise ValueError("only integer and float constants have literal values")
@@ -108,3 +90,19 @@ def constant(
     result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
 ) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+def index_cast(
+    in_: Value,
+    to: Type = None,
+    *,
+    out: Type = None,
+    loc: Location = None,
+    ip: InsertionPoint = None,
+) -> Value:
+    if bool(to) != bool(out):
+        raise ValueError("either `to` or `out` must be set but not both")
+    res_type = out or to
+    if res_type is None:
+        res_type = IndexType.get()
+    return _get_op_result_or_op_results(IndexCastOp(res_type, in_, loc=loc, ip=ip))
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..cae70fc03b9d6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -5,6 +5,13 @@
 from typing import Callable, Dict, List, Sequence, Tuple, Union
 
 from .....ir import *
+from .....util import (
+    is_complex_type,
+    is_float_type,
+    is_index_type,
+    is_integer_type,
+    get_floating_point_width,
+)
 
 from .... import func
 from .... import linalg
@@ -412,9 +419,9 @@ def _cast(
             )
         if operand.type == to_type:
             return operand
-        if _is_integer_type(to_type):
+        if is_integer_type(to_type):
             return self._cast_to_integer(to_type, operand, is_unsigned_cast)
-        elif _is_floating_point_type(to_type):
+        elif is_float_type(to_type):
             return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
 
     def _cast_to_integer(
@@ -422,11 +429,11 @@ def _cast_to_integer(
     ) -> Value:
         to_width = IntegerType(to_type).width
         operand_type = operand.type
-        if _is_floating_point_type(operand_type):
+        if is_float_type(operand_type):
             if is_unsigned_cast:
                 return arith.FPToUIOp(to_type, operand).result
             return arith.FPToSIOp(to_type, operand).result
-        if _is_index_type(operand_type):
+        if is_index_type(operand_type):
             return arith.IndexCastOp(to_type, operand).result
         # Assume integer.
         from_width = IntegerType(operand_type).width
@@ -444,13 +451,13 @@ def _cast_to_floating_point(
         self, to_type: Type, operand: Value, is_unsigned_cast: bool
     ) -> Value:
         operand_type = operand.type
-        if _is_integer_type(operand_type):
+        if is_integer_type(operand_type):
             if is_unsigned_cast:
                 return arith.UIToFPOp(to_type, operand).result
             return arith.SIToFPOp(to_type, operand).result
         # Assume FloatType.
-        to_width = _get_floating_point_width(to_type)
-        from_width = _get_floating_point_width(operand_type)
+        to_width = get_floating_point_width(to_type)
+        from_width = get_floating_point_width(operand_type)
         if to_width > from_width:
             return arith.ExtFOp(to_type, operand).result
         elif to_width < from_width:
@@ -466,89 +473,89 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
         return self._cast(type_var_name, operand, True)
 
     def _unary_exp(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.AbsFOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.CeilOp(x).result
         raise NotImplementedError("Unsupported 'ceil' operand: {x}")
 
     def _unary_floor(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.FloorOp(x).result
         raise NotImplementedError("Unsupported 'floor' operand: {x}")
 
     def _unary_negf(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return arith.NegFOp(x).result
-        if _is_complex_type(x.type):
+        if is_complex_type(x.type):
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.AddFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.AddIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if is_complex_type(lhs.type):
             return complex.AddOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
 
     def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.SubFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.SubIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if is_complex_type(lhs.type):
             return complex.SubOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
 
     def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MulFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MulIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if is_complex_type(lhs.type):
             return complex.MulOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
 
     def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MaximumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MaxSIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
 
     def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MaximumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MaxUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
 
     def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MinimumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MinSIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
 
     def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MinimumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MinUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
 
@@ -609,40 +616,3 @@ def _add_type_mapping(
             )
     type_mapping[name] = element_or_self_type
     block_arg_types.append(element_or_self_type)
-
-
-def _is_complex_type(t: Type) -> bool:
-    return ComplexType.isinstance(t)
-
-
-def _is_floating_point_type(t: Type) -> bool:
-    # TODO: Create a FloatType in the Python API and implement the switch
-    # there.
-    return (
-        F64Type.isinstance(t)
-        or F32Type.isinstance(t)
-        or F16Type.isinstance(t)
-        or BF16Type.isinstance(t)
-    )
-
-
-def _is_integer_type(t: Type) -> bool:
-    return IntegerType.isinstance(t)
-
-
-def _is_index_type(t: Type) -> bool:
-    return IndexType.isinstance(t)
-
-
-def _get_floating_point_width(t: Type) -> int:
-    # TODO: Create a FloatType in the Python API and implement the switch
-    # there.
-    if F64Type.isinstance(t):
-        return 64
-    if F32Type.isinstance(t):
-        return 32
-    if F16Type.isinstance(t):
-        return 16
-    if BF16Type.isinstance(t):
-        return 16
-    raise NotImplementedError(f"Unhandled floating point type switch {t}")
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index bc9a3a52728ad..2130a1966d88e 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -7,8 +7,9 @@
 
 from ._memref_ops_gen import *
 from ._ods_common import _dispatch_mixed_values, MixedValues
-from .arith import ConstantOp, _is_integer_like_type
+from .arith import ConstantOp
 from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
+from ..util import is_integer_like_type
 
 
 def _is_constant_int_like(i):
@@ -16,7 +17,7 @@ def _is_constant_int_like(i):
         isinstance(i, Value)
         and isinstance(i.owner, Operation)
         and isinstance(i.owner.opview, ConstantOp)
-        and _is_integer_like_type(i.type)
+        and is_integer_like_type(i.type)
     )
 
 
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 678ceeebac204..3b58d5c1c48b6 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -5,10 +5,12 @@
 
 from ._scf_ops_gen import *
 from ._scf_ops_gen import _Dialect
-from .arith import constant
+from . import arith
+from ..extras.meta import region_op, region_adder
 
 try:
     from ..ir import *
+    from ..util import is_index_type
     from ._ods_common import (
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
@@ -237,7 +239,7 @@ def for_(
     params = [start, stop, step]
     for i, p in enumerate(params):
         if isinstance(p, int):
-            p = constant(IndexType.get(), p)
+            p = arith.constant(IndexType.get(), p)
         elif isinstance(p, float):
             raise ValueError(f"{p=} must be int.")
         params[i] = p
@@ -254,3 +256,134 @@ def for_(
             yield iv, iter_args[0], for_op.results[0]
         else:
             yield iv
+
+
+def _parfor(op_ctor):
+    def _base(
+        lower_bounds, upper_bounds=None, steps=None, *, loc=None, ip=None, **kwargs
+    ):
+        if upper_bounds is None:
+            upper_bounds = lower_bounds
+            lower_bounds = [0] * len(upper_bounds)
+        if steps is None:
+            steps = [1] * len(lower_bounds)
+
+        params = [lower_bounds, upper_bounds, steps]
+        for i, p in enumerate(params):
+            for j, pp in enumerate(p):
+                if isinstance(pp, int):
+                    pp = arith.constant(IndexType.get(), pp)
+                assert isinstance(pp, Value), f"expected ir.Value, got {type(pp)=}"
+                if not is_index_type(pp.type):
+                    pp = arith.index_cast(pp)
+                p[j] = pp
+            params[i] = p
+
+        return op_ctor(*params, loc=loc, ip=ip, **kwargs)
+
+    return _base
+
+
+def _parfor_cm(op_ctor):
+    def _base(*args, **kwargs):
+        for_op = _parfor(op_ctor)(*args, **kwargs)
+        block = for_op.regions[0].blocks[0]
+        block_args = tuple(block.arguments)
+        with InsertionPoint(block):
+            yield block_args
+
+    return _base
+
+
+forall = _parfor_cm(ForallOp)
+
+
+class ParallelOp(ParallelOp):
+    def __init__(
+        self,
+        lower_bounds,
+        upper_bounds,
+        steps,
+        inits: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        assert len(lower_bounds) == len(upper_bounds) == len(steps)
+        if inits is None:
+            inits = []
+        results = [i.type for i in inits]
+        iv_types = [IndexType.get()] * len(lower_bounds)
+        super().__init__(
+            results,
+            lower_bounds,
+            upper_bounds,
+            steps,
+            inits,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*iv_types)
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+    @property
+    def induction_variables(self):
+        return self.body.arguments
+
+
+parallel = _parfor_cm(ParallelOp)
+
+
+class ReduceOp(ReduceOp):
+    def __init__(self, operands, num_reductions, *, loc=None, ip=None):
+        super().__init__(operands, num_reductions, loc=loc, ip=ip)
+        for i in range(num_reductions):
+            self.regions[i].blocks.append(operands[i].type, operands[i].type)
+
+
+def reduce_(*operands, num_reductions=1, loc=None, ip=None):
+    return ReduceOp(operands, num_reductions, loc=loc, ip=ip)
+
+
+reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))
+
+
+@region_adder(terminator=lambda xs: reduce_return(*xs))
+def another_reduce(reduce_op):
+    for r in reduce_op.regions:
+        if len(r.blocks[0].operations) == 0:
+            return r
+
+
+@region_op
+def in_parallel():
+    return InParallelOp()
+
+
+def parallel_insert_slice(
+    source,
+    dest,
+    static_offsets=None,
+    static_sizes=None,
+    static_strides=None,
+    offsets=None,
+    sizes=None,
+    strides=None,
+):
+    from . import tensor
+
+    @in_parallel
+    def foo():
+        tensor.parallel_insert_slice(
+            source,
+            dest,
+            offsets,
+            sizes,
+            strides,
+            static_offsets,
+            static_sizes,
+            static_strides,
+        )
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 146b5f85d07f5..b1baa22b15e23 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -65,3 +65,45 @@ def empty(
     lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
     terminator=lambda args: YieldOp(args[0]),
 )
+
+
+def parallel_insert_slice(
+    source,
+    dest,
+    offsets=None,
+    sizes=None,
+    strides=None,
+    static_offsets=None,
+    static_sizes=None,
+    static_strides=None,
+):
+    S = ShapedType.get_dynamic_size()
+    if static_offsets is None:
+        assert offsets is not None
+        static_offsets = [S, S]
+    if static_sizes is None:
+        assert sizes is not None
+        static_sizes = [S, S]
+    if static_strides is None:
+        assert strides is not None
+        static_strides = [S, S]
+    if offsets is None:
+        assert static_offsets
+        offsets = []
+    if sizes is None:
+        assert static_sizes
+        sizes = []
+    if strides is None:
+        assert static_strides
+        strides = []
+
+    return ParallelInsertSliceOp(
+        source,
+        dest,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+    )
diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
index 3f2defadf7941..fabe1d8e141ed 100644
--- a/mlir/python/mlir/extras/meta.py
+++ b/mlir/python/mlir/extras/meta.py
@@ -6,7 +6,7 @@
 from functools import wraps
 
 from ..dialects._ods_common import get_op_result_or_op_results
-from ..ir import Type, InsertionPoint
+from ..ir import Type, InsertionPoint, Value
 
 
 def op_region_builder(op, op_region, terminator=None):
@@ -81,3 +81,17 @@ def maybe_no_args(*args, **kwargs):
             return op_decorator(*args, **kwargs)
 
     return maybe_no_args
+
+
+def region_adder(terminator=None):
+    def wrapper(op_region_adder):
+        def region_adder_decorator(op, *args, **kwargs):
+            if isinstance(op, Value):
+                op = op.owner.opview
+            region = op_region_adder(op, *args, **kwargs)
+
+            return op_region_builder(op, region, terminator)
+
+        return region_adder_decorator
+
+    return wrapper
diff --git a/mlir/python/mlir/util.py b/mlir/python/mlir/util.py
new file mode 100644
index 0000000000000..453b74777014f
--- /dev/null
+++ b/mlir/python/mlir/util.py
@@ -0,0 +1,51 @@
+from .ir import (
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    IndexType,
+    IntegerType,
+    Type,
+)
+
+
+def is_complex_type(t: Type) -> bool:
+    return ComplexType.isinstance(t)
+
+
+def is_float_type(t: Type) -> bool:
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    return (
+        F64Type.isinstance(t)
+        or F32Type.isinstance(t)
+        or F16Type.isinstance(t)
+        or BF16Type.isinstance(t)
+    )
+
+
+def is_integer_type(t: Type) -> bool:
+    return IntegerType.isinstance(t)
+
+
+def is_index_type(t: Type) -> bool:
+    return IndexType.isinstance(t)
+
+
+def is_integer_like_type(t: Type) -> bool:
+    return is_integer_type(t) or is_index_type(t)
+
+
+def get_floating_point_width(t: Type) -> int:
+    # TODO: Create a FloatType in the Python API and implement the switch
+  ...
[truncated]
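The truncated util.py carries over the same floating-point width switch the emitter used (`_get_floating_point_width` in the diff above). A plain-Python sketch of that switch, keyed on type names rather than mlir.ir classes (the function name and dict are illustrative, not the bindings API):

```python
def float_width(type_name: str) -> int:
    # Same mapping as the emitter helper: f64 -> 64, f32 -> 32,
    # f16 -> 16, bf16 -> 16; anything else is unhandled.
    widths = {"f64": 64, "f32": 32, "f16": 16, "bf16": 16}
    try:
        return widths[type_name]
    except KeyError:
        raise NotImplementedError(
            f"Unhandled floating point type switch {type_name}"
        )
```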

@makslevental
Contributor Author

cc @Cubevoid - feel free to review if you'd like.

@makslevental makslevental requested a review from jpienaar August 8, 2025 05:12
def is_float_type(t: Type) -> bool:
    # TODO: Create a FloatType in the Python API and implement the switch
    # there.
    return (
Contributor

Super nit: would isinstance(t, (F64Type, F32Type, F16Type, BF16Type)) work?

Contributor Author

Type.isinstance predates https://reviews.llvm.org/D150927 (i.e. when t: Type was "abstract" and didn't have the correct matching Python type). So it'll work (I'm 99.99999% sure, but I'll check), but I'd rather leave it as is because Type.isinstance is 100%.
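For context, the util.py predicates being discussed replace the old constructor-probe pattern that arith.py used (`_isa` in the diff above, since removed). A standalone sketch of that pattern with a toy class standing in for an mlir.ir type (both names hypothetical):

```python
def isa(obj, cls):
    # Probe by attempting the casting constructor, as the old
    # arith.py _isa helper did: a failed cast raises ValueError.
    try:
        cls(obj)
    except ValueError:
        return False
    return True


class OnlyInts:
    # Toy stand-in for a casting constructor that rejects mismatches.
    def __init__(self, v):
        if not isinstance(v, int):
            raise ValueError("not an int")
```

With this, `isa(3, OnlyInts)` is True and `isa("x", OnlyInts)` is False.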

def res1(lhs: index_type, rhs: index_type):
    return arith.addi(lhs, rhs)

@scf.another_reduce(res1)
Contributor

This has me wondering whether this is the best API to expose here: I needed to go back up the PR to understand that this is a new helper in the scf namespace that does not have a counterpart in the dialect. It is just there to tie the multiple regions together.

I understand (nested) functions are the natural way to represent (single-block) regions. I am just wondering whether we want these ad-hoc methods for tacking regions together, or whether there's a principled way/API that we could use to represent that ops can have multiple regions (just focused on single-block regions for now).

Contributor Author

or if there's a principled way/API that we could use to represent that ops can have multiple regions

I have pondered this conundrum for years now (well just like ~2 years) and the only other way to do this is using something like

with another_reduce(res1) as (lhs, rhs):
    ...

which is inferior because

  1. lhs, rhs will "leak" from the indented block (context managers don't have scope in Python)
  2. you cannot annotate lhs: index_type, rhs: index_type so you actually have to pass the types to with another_reduce(res1, index_type, index_type), which is pretty ugly IMHO

scf namespace

I know it's nice to treat the Python modules as synonymous with the MLIR dialect namespaces, but they're not really, right? Plenty of helpers etc. live in lots of these files. On the other hand, if we did make scf a package (i.e. put it into a folder with __init__.py), then we could have a module in there called something like helpers (and all the ops would just be top-level in the __init__.py). It probably would've been prudent/diligent to do that from the beginning, but today, given the rest of the dialects won't be organized this way (at least not initially), I don't think it's worth it - just keep the status quo, understanding that the Python modules aren't synonymous with the MLIR dialect namespaces.
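The scoping point in (1) above is easy to verify in plain Python with a toy context manager (no MLIR involved):

```python
from contextlib import contextmanager


@contextmanager
def region(n_args):
    # Toy stand-in for a region-building context manager that
    # yields the region's block arguments.
    yield tuple(f"arg{i}" for i in range(n_args))


with region(2) as (lhs, rhs):
    body = (lhs, rhs)

# A with-block introduces no scope: its targets survive it,
# so lhs and rhs are still bound here.
leaked = (lhs, rhs)
```

After the block, `leaked == ("arg0", "arg1")` - exactly the leak described above.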

Contributor

I hear ya. At the same time, I feel one of the bigger pain points for newcomers to the Python bindings is the element of surprise. E.g. that the namespaces in Python are not always a direct reflection of the same-name namespaces as dialects (or that the snake-case ops are not always available, etc.).

I agree with your analysis that functions are preferable over context managers. I hope (perhaps naively so) that we can still find a pattern that we can consistently apply across the bindings code base, especially for new code. The following still has a high degree of surprise, though it is something we could apply everywhere:

How about making regions usable as decorators? That is, have Region implement __call__ (which is rather surprising 😢 ) though it would then enable the following:

red = scf.reduce(i, j, num_regions=2)
@red.regions[0]
def reg0(lhs: index_type, rhs: index_type):
  return arith.addi(lhs, rhs)

@red.regions[1]
def reg1(lhs: index_type, rhs: index_type):
  return arith.addi(lhs, rhs)

I think such a thing would also allow us to forgo having the region_op decorator for, e.g., linalg.generic. So instead of writing

    @linalg.generic([A, B], [C], affine_maps, iterator_types)
    def f(a, b, c):
        ...

we could write

    @linalg.generic([A, B], [C], affine_maps, iterator_types).region
    def f(a, b, c):
        ...

which to me is less surprising and means there are fewer special cases/ops to be mindful of. (Though I am not sure it works out here if a, b and c are not typed - hmm.)

Maybe @nicolasvasilache could chime in regarding making the bindings more consistent.

Contributor Author

@makslevental makslevental Aug 13, 2025

How about making regions usable as decorators? That is, have Region implement __call__ (which is rather surprising 😢) though it would then enable the following:

You can't add that to PyRegion because it models mlir::Region, which can stand alone. You could still pull off this trick by changing what op.region returns, but that would be a breaking change. Also, the index is superfluous because you can't add regions in arbitrary order (only append). So your region accessor would just be the same region_op but as a method rather than a free function (though I guess it might return the region instead of the op?).

bigger pain points for newcomers to the Python bindings is the element of surprise

I don't disagree with you - I think about that absolutely all the time w.r.t. many of the disparate features/patterns in the bindings. Unfortunately it's just an artifact of the fact that they weren't meant to be used for a DSL (they were meant to be used only to write tests). That, and a lack of consistent ownership/development/plan. So I would like for this to be better, but at this point we'd need to rewrite large parts to achieve the kind of consistency you'd like (we'd like). Spoiler alert: I plan to propose exactly such a rewrite in the coming months. But until/if then, the only thing I believe we can do is smooth over things like this downstream. I.e., the bindings function as a toolkit to build a DSL but are not the DSL itself. Prime example: JAX (modulo their whole stacked-interpreter thing). I know that's not a great answer because it

  1. seems like it shirks responsibility upstream (not my intent)
  2. requires lots of boiler-platey code downstream

Vis-à-vis the prospective rewrite, I intend to propose adding tooling to enable extending TableGen for the bindings downstream (so one would be able to generate a backwards-breaking .region for their ops if they so chose).

Contributor Author

You could still pull off this trick by changing what op.region returns but that would be a breaking change

Actually the current accessor is regions not region so this wouldn't actually be a breaking change. But I would prefer the name still be something like add_region so that people don't expect op.region to return a region in the single-region case.
