diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index bf54efee1f14e..bc2e676a878c0 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -1017,90 +1017,79 @@ very generic signature. #### Extending Generated Op Classes -Note that this is a rather complex mechanism and this section errs on the side -of explicitness. Users are encouraged to find an example and duplicate it if -they don't feel the need to understand the subtlety. The `builtin` dialect -provides some relatively simple examples. - As mentioned above, the build system generates Python sources like `_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is -often desirable to to use these generated classes as a starting point for -further customization, so an extension mechanism is provided to make this easy -(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file -but we prefer a more standard mechanism that is applied uniformly). +often desirable to use these generated classes as a starting point for +further customization, so an extension mechanism is provided to make this easy. +This mechanism uses conventional inheritance combined with `OpView` registration. +For example, the default builder for `arith.constant` + +```python +class ConstantOp(_ods_ir.OpView): + OPERATION_NAME = "arith.constant" + + _ODS_REGIONS = (0, True) + + def __init__(self, value, *, loc=None, ip=None): + ... +``` -To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the -`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and -the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an -example, the generated code will include an import like this: +expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`). +Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`: ```python -try: - from . import _builtin_ops_ext as _ods_ext_module -except ImportError: - _ods_ext_module = None +from typing import Union + +from mlir.ir import Type, IntegerAttr, FloatAttr +from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp +from mlir.dialects._ods_common import _cext + +@_cext.register_operation(_Dialect, replace=True) +class ConstantOpExt(ConstantOp): + def __init__( + self, result: Type, value: Union[int, float], *, loc=None, ip=None + ): + if isinstance(value, int): + super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + else: + raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}") ``` -Then for each generated concrete `OpView` subclass, it will apply a decorator -like: +which enables building an instance of `arith.constant` like so: ```python -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class FuncOp(_ods_ir.OpView): +from mlir.ir import F32Type + +a = ConstantOpExt(F32Type.get(), 42.42) +b = ConstantOpExt(IntegerType.get_signless(32), 42) ``` -See the `_ods_common.py` `extend_opview_class` function for details of the -mechanism. At a high level: - -* If the extension module exists, locate an extension class for the op (in - this example, `FuncOp`): - * First by looking for an attribute with the exact name in the extension - module. - * Falling back to calling a `select_opview_mixin(parent_opview_cls)` - function defined in the extension module. -* If a mixin class is found, a new subclass is dynamically created that - multiply inherits from `({_builtin_ops_ext.FuncOp}, - _builtin_ops_gen.FuncOp)`. - -The mixin class should not inherit from anything (i.e. directly extends `object` -only). The facility is typically used to define custom `__init__` methods, -properties, instance methods and static methods. Due to the inheritance -ordering, the mixin class can act as though it extends the generated `OpView` -subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)` -will return `False` but usage generally allows you treat it as duck typed as an -`OpView`). - -There are a couple of recommendations, given how the class hierarchy is defined: - -* For static methods that need to instantiate the actual "leaf" op (which is - dynamically generated and would result in circular dependencies to try to - reference by name), prefer to use `@classmethod` and the concrete subclass - will be provided as your first `cls` argument. See - `_builtin_ops_ext.FuncOp.from_py_func` as an example. -* If seeking to replace the generated `__init__` method entirely, you may - actually want to invoke the super-super-class `mlir.ir.OpView` constructor - directly, as it takes an `mlir.ir.Operation`, which is likely what you are - constructing (i.e. the generated `__init__` method likely adds more API - constraints than you want to expose in a custom builder). - -A pattern that comes up frequently is wanting to provide a sugared `__init__` -method which has optional or type-polymorphism/implicit conversions but to -otherwise want to invoke the default op building logic. For such cases, it is -recommended to use an idiom such as: +Note, three key aspects of the extension mechanism in this example: + +1. `ConstantOpExt` directly inherits from the generated `ConstantOp`; +2. in this, simplest, case all that's required is a call to the super class' initializer, i.e., `super().__init__(...)`; +3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks)) + we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where the `replace=True` must be used**. + +In some more complex cases it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as is performed by the generated builders. +I.e., we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**, where the small issue becomes that the latter is already overridden by the generated builder. +Thus, we must call a method of a super class' super class (the "grandparent"); for example: ```python - def __init__(self, sugar, spice, *, loc=None, ip=None): - ... massage into result_type, operands, attributes ... - OpView.__init__(self, self.build_generic( - results=[result_type], - operands=operands, - attributes=attributes, - loc=loc, - ip=ip)) +from mlir.dialects._scf_ops_gen import _Dialect, ForOp +from mlir.dialects._ods_common import _cext + +@_cext.register_operation(_Dialect, replace=True) +class ForOpExt(ForOp): + def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None): + ... + super(ForOp, self).__init__(self.build_generic(...)) ``` -Refer to the documentation for `build_generic` for more information. +where `OpView.__init__` is called via `super(ForOp, self).__init__`. +Note, there are alternatives ways to implement this (e.g., explicitly writing `OpView.__init__`); see any discussion on Python inheritance. ## Providing Python bindings for a dialect diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 97cd70089a2e9..21899bdce22e8 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -77,10 +77,10 @@ class PyGlobals { pybind11::object pyClass); /// Adds a concrete implementation operation class. - /// Raises an exception if the mapping already exists. + /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass); + pybind11::object pyClass, bool replace = false); /// Returns the custom Attribute builder for Attribute kind. std::optional diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 2cc66277abee0..a1c8ab7a09ce1 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass) { + py::object pyClass, bool replace) { py::object &found = operationClassMap[operationName]; - if (found) { + if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + "' is already registered.") .str()); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index cdddfbe50606d..a936becf67bea 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, + "operation_name"_a, "operation_class"_a, "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::object &dialectClass) -> py::cpp_function { + [](const py::object &dialectClass, bool replace) -> py::cpp_function { return py::cpp_function( - [dialectClass](py::object opClass) -> py::object { + [dialectClass, replace](py::object opClass) -> py::object { std::string operationName = opClass.attr("OPERATION_NAME").cast(); - PyGlobals::get().registerOperationImpl(operationName, opClass); + PyGlobals::get().registerOperationImpl(operationName, opClass, + replace); // Dict-stuff the new opClass by name onto the dialect class. py::object opClassName = opClass.attr("__name__"); @@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) { return opClass; }); }, - "dialect_class"_a, + "dialect_class"_a, "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index c7b3c283a6b6d..88e6e13602d29 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/AffineOps.td SOURCES dialects/affine.py - dialects/_affine_ops_ext.py DIALECT_NAME affine GEN_ENUM_BINDINGS) @@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/BufferizationOps.td SOURCES dialects/bufferization.py - dialects/_bufferization_ops_ext.py DIALECT_NAME bufferization GEN_ENUM_BINDINGS_TD_FILE "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td" @@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/BuiltinOps.td SOURCES dialects/builtin.py - dialects/_builtin_ops_ext.py DIALECT_NAME builtin) declare_mlir_dialect_python_bindings( @@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/FuncOps.td SOURCES dialects/func.py - dialects/_func_ops_ext.py DIALECT_NAME func) declare_mlir_dialect_python_bindings( @@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/LinalgOps.td SOURCES - dialects/_linalg_ops_ext.py SOURCES_GLOB dialects/linalg/*.py DIALECT_NAME linalg @@ -152,7 +147,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TransformPDLExtensionOps.td SOURCES - dialects/_transform_pdl_extension_ops_ext.py dialects/transform/pdl.py DIALECT_NAME transform EXTENSION_NAME transform_pdl_extension) @@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TransformOps.td SOURCES - dialects/_transform_ops_ext.py dialects/transform/__init__.py _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform @@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/BufferizationTransformOps.td SOURCES - dialects/_bufferization_transform_ops_ext.py dialects/transform/bufferization.py DIALECT_NAME transform EXTENSION_NAME bufferization_transform) @@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUTransformOps.td SOURCES - dialects/_gpu_transform_ops_ext.py dialects/transform/gpu.py DIALECT_NAME transform EXTENSION_NAME gpu_transform) @@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/SCFLoopTransformOps.td SOURCES - dialects/_loop_transform_ops_ext.py dialects/transform/loop.py DIALECT_NAME transform EXTENSION_NAME loop_transform) @@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/MemRefTransformOps.td SOURCES - dialects/_memref_transform_ops_ext.py dialects/transform/memref.py DIALECT_NAME transform EXTENSION_NAME memref_transform) @@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/LinalgStructuredTransformOps.td SOURCES - dialects/_structured_transform_ops_ext.py dialects/transform/structured.py DIALECT_NAME transform EXTENSION_NAME structured_transform @@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TensorTransformOps.td SOURCES - dialects/_tensor_transform_ops_ext.py dialects/transform/tensor.py DIALECT_NAME transform EXTENSION_NAME tensor_transform) @@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/ArithOps.td SOURCES dialects/arith.py - dialects/_arith_ops_ext.py DIALECT_NAME arith GEN_ENUM_BINDINGS) @@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/MemRefOps.td SOURCES dialects/memref.py - dialects/_memref_ops_ext.py DIALECT_NAME memref) declare_mlir_dialect_python_bindings( @@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/MLProgramOps.td SOURCES dialects/ml_program.py - dialects/_ml_program_ops_ext.py DIALECT_NAME ml_program) declare_mlir_dialect_python_bindings( @@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/PDLOps.td SOURCES dialects/pdl.py - dialects/_pdl_ops_ext.py _mlir_libs/_mlir/dialects/pdl.pyi DIALECT_NAME pdl) @@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/SCFOps.td SOURCES dialects/scf.py - dialects/_scf_ops_ext.py DIALECT_NAME scf) declare_mlir_dialect_python_bindings( @@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/TensorOps.td SOURCES dialects/tensor.py - dialects/_tensor_ops_ext.py DIALECT_NAME tensor) declare_mlir_dialect_python_bindings( diff --git a/mlir/python/mlir/dialects/_affine_ops_ext.py b/mlir/python/mlir/dialects/_affine_ops_ext.py deleted file mode 100644 index dc465ce7aa1e5..0000000000000 --- a/mlir/python/mlir/dialects/_affine_ops_ext.py +++ /dev/null @@ -1,56 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ._ods_common import get_op_results_or_values as _get_op_results_or_values -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - - -class AffineStoreOp: - """Specialization for the Affine store operation.""" - - def __init__( - self, - value: Union[Operation, OpView, Value], - memref: Union[Operation, OpView, Value], - map: AffineMap=None, - *, - map_operands=None, - loc=None, - ip=None - ): - """Creates an affine store operation. - - - `value`: the value to store into the memref. - - `memref`: the buffer to store into. - - `map`: the affine map that maps the map_operands to the index of the - memref. - - `map_operands`: the list of arguments to substitute the dimensions, - then symbols in the affine map, in increasing order. - """ - map = map if map is not None else [] - map_operands = map_operands if map_operands is not None else [] - operands = [ - _get_op_result_or_value(value), - _get_op_result_or_value(memref), - *[_get_op_result_or_value(op) for op in map_operands] - ] - results = [] - attributes = {"map": AffineMapAttr.get(map)} - regions = None - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, - results=results, - operands=operands, - successors=_ods_successors, - regions=regions, - loc=loc, - ip=ip - )) diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py deleted file mode 100644 index df38f871710fe..0000000000000 --- a/mlir/python/mlir/dialects/_arith_ops_ext.py +++ /dev/null @@ -1,69 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context - - from typing import Any, List, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - - -def _isa(obj: Any, cls: type): - try: - cls(obj) - except ValueError: - return False - return True - - -def _is_any_of(obj: Any, classes: List[type]): - return any(_isa(obj, cls) for cls in classes) - - -def _is_integer_like_type(type: Type): - return _is_any_of(type, [IntegerType, IndexType]) - - -def _is_float_type(type: Type): - return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) - - -class ConstantOp: - """Specialization for the constant op class.""" - - def __init__( - self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None - ): - if isinstance(value, int): - super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) - elif isinstance(value, float): - super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) - else: - super().__init__(value, loc=loc, ip=ip) - - @classmethod - def create_index(cls, value: int, *, loc=None, ip=None): - """Create an index-typed constant.""" - return cls( - IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip - ) - - @property - def type(self): - return self.results[0].type - - @property - def value(self): - return Attribute(self.operation.attributes["value"]) - - @property - def literal_value(self) -> Union[int, float]: - if _is_integer_like_type(self.type): - return IntegerAttr(self.value).value - elif _is_float_type(self.type): - return FloatAttr(self.value).value - else: - raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py deleted file mode 100644 index 1066cb4c775ca..0000000000000 --- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py +++ /dev/null @@ -1,41 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from typing import Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context - - from typing import Any, List, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - - -class AllocTensorOp: - """Extends the bufferization.alloc_tensor op.""" - - def __init__( - self, - tensor_type: Type, - dynamic_sizes: Sequence[Value], - copy: Value, - size_hint: Value, - escape: BoolAttr, - *, - loc=None, - ip=None - ): - """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" - context = get_default_loc_context(loc) - attributes = {} - if escape: - attributes["escape"] = escape - op = self.build_generic( - results=[tensor_type], - operands=[dynamic_sizes, copy, size_hint], - attributes=attributes, - loc=loc, - ip=ip, - ) - OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py deleted file mode 100644 index 7e6c1b81cb350..0000000000000 --- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py +++ /dev/null @@ -1,128 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from enum import Enum -from typing import Optional, overload, Union - - -class EmptyTensorToAllocTensorOp: - """Specialization for EmptyTensorToAllocTensorOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - *, - loc=None, - ip=None - ): - ... - - @overload - def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): - ... - - def __init__( - self, - transformed_type_or_target: Type, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_none - else: - transformed_type = transform.OperationType.get("bufferization.alloc_tensor") - target = transformed_type_or_target - - super().__init__( - transformed_type, - target, - loc=loc, - ip=ip, - ) - - -class OneShotBufferizeOp: - """Specialization for OneShotBufferizeOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - *, - allow_return_allocs_from_loops: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - allow_return_allocs_from_loops: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - def __init__( - self, - transformed_type_or_target: Type, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - allow_return_allocs_from_loops: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target - - super().__init__( - transformed_type, - target, - allow_return_allocs_from_loops=allow_return_allocs_from_loops, - allow_unknown_ops=allow_unknown_ops, - bufferize_function_boundaries=bufferize_function_boundaries, - function_boundary_type_conversion=function_boundary_type_conversion, - memcpy_op=memcpy_op, - print_conflicts=print_conflicts, - test_analysis_only=test_analysis_only, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py deleted file mode 100644 index 27a60123050ac..0000000000000 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ /dev/null @@ -1,20 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - - -class ModuleOp: - """Specialization for the module op class.""" - - def __init__(self, *, loc=None, ip=None): - super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip)) - body = self.regions[0].blocks.append() - - @property - def body(self): - return self.regions[0].blocks[0] diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py deleted file mode 100644 index 6d264c33f1f9d..0000000000000 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ /dev/null @@ -1,319 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context - - import inspect - - from typing import Any, List, Optional, Sequence, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" -RESULT_ATTRIBUTE_NAME = "res_attrs" - - -class ConstantOp: - """Specialization for the constant op class.""" - - def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): - super().__init__(result, value, loc=loc, ip=ip) - - @property - def type(self): - return self.results[0].type - - -class FuncOp: - """Specialization for the func op class.""" - - def __init__( - self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None - ): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. - if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = ( - StringAttr.get(str(visibility)) if visibility is not None else None - ) - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError("External function does not have a body") - return self.regions[0].blocks[0] - - def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): - """ - Add an entry block to the function body using the function signature to - infer block arguments. - Returns the newly created block - """ - if not self.is_external: - raise IndexError("The function already has an entry block!") - self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context - ) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute - - @classmethod - def from_py_func( - FuncOp, - *inputs: Type, - results: Optional[Sequence[Type]] = None, - name: Optional[str] = None, - ): - """Decorator to define an MLIR FuncOp specified as a python function. - - Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are - active for the current thread (i.e. established in a `with` block). - - When applied as a decorator to a Python function, an entry block will - be constructed for the FuncOp with types as specified in `*inputs`. The - block arguments will be passed positionally to the Python function. In - addition, if the Python function accepts keyword arguments generally or - has a corresponding keyword argument, the following will be passed: - * `func_op`: The `func` op being defined. - - By default, the function name will be the Python function `__name__`. This - can be overriden by passing the `name` argument to the decorator. - - If `results` is not specified, then the decorator will implicitly - insert a `ReturnOp` with the `Value`'s returned from the decorated - function. It will also set the `FuncOp` type with the actual return - value types. If `results` is specified, then the decorated function - must return `None` and no implicit `ReturnOp` is added (nor are the result - types updated). The implicit behavior is intended for simple, single-block - cases, and users should specify result types explicitly for any complicated - cases. - - The decorated function can further be called from Python and will insert - a `CallOp` at the then-current insertion point, returning either None ( - if no return values), a unary Value (for one result), or a list of Values). - This mechanism cannot be used to emit recursive calls (by construction). - """ - - def decorator(f): - from . import func - - # Introspect the callable for optional features. - sig = inspect.signature(f) - has_arg_func_op = False - for param in sig.parameters.values(): - if param.kind == param.VAR_KEYWORD: - has_arg_func_op = True - if param.name == "func_op" and ( - param.kind == param.POSITIONAL_OR_KEYWORD - or param.kind == param.KEYWORD_ONLY - ): - has_arg_func_op = True - - # Emit the FuncOp. - implicit_return = results is None - symbol_name = name or f.__name__ - function_type = FunctionType.get( - inputs=inputs, results=[] if implicit_return else results - ) - func_op = FuncOp(name=symbol_name, type=function_type) - with InsertionPoint(func_op.add_entry_block()): - func_args = func_op.entry_block.arguments - func_kwargs = {} - if has_arg_func_op: - func_kwargs["func_op"] = func_op - return_values = f(*func_args, **func_kwargs) - if not implicit_return: - return_types = list(results) - assert return_values is None, ( - "Capturing a python function with explicit `results=` " - "requires that the wrapped function returns None." - ) - else: - # Coerce return values, add ReturnOp and rewrite func type. - if return_values is None: - return_values = [] - elif isinstance(return_values, tuple): - return_values = list(return_values) - elif isinstance(return_values, Value): - # Returning a single value is fine, coerce it into a list. - return_values = [return_values] - elif isinstance(return_values, OpView): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.operation.results - elif isinstance(return_values, Operation): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.results - else: - return_values = list(return_values) - func.ReturnOp(return_values) - # Recompute the function type. - return_types = [v.type for v in return_values] - function_type = FunctionType.get( - inputs=inputs, results=return_types - ) - func_op.attributes["function_type"] = TypeAttr.get(function_type) - - def emit_call_op(*call_args): - call_op = func.CallOp( - return_types, FlatSymbolRefAttr.get(symbol_name), call_args - ) - if return_types is None: - return None - elif len(return_types) == 1: - return call_op.result - else: - return call_op.results - - wrapped = emit_call_op - wrapped.__name__ = f.__name__ - wrapped.func_op = func_op - return wrapped - - return decorator - - -class CallOp: - """Specialization for the call op class.""" - - def __init__( - self, - calleeOrResults: Union[FuncOp, List[Type]], - argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], - arguments: Optional[List] = None, - *, - loc=None, - ip=None, - ): - """Creates an call operation. - - The constructor accepts three different forms: - - 1. A function op to be called followed by a list of arguments. - 2. A list of result types, followed by the name of the function to be - called as string, following by a list of arguments. - 3. A list of result types, followed by the name of the function to be - called as symbol reference attribute, followed by a list of arguments. - - For example - - f = func.FuncOp("foo", ...) - func.CallOp(f, [args]) - func.CallOp([result_types], "foo", [args]) - - In all cases, the location and insertion point may be specified as keyword - arguments if not provided by the surrounding context managers. - """ - - # TODO: consider supporting constructor "overloads", e.g., through a custom - # or pybind-provided metaclass. - if isinstance(calleeOrResults, FuncOp): - if not isinstance(argumentsOrCallee, list): - raise ValueError( - "when constructing a call to a function, expected " - + "the second argument to be a list of call arguments, " - + f"got {type(argumentsOrCallee)}" - ) - if arguments is not None: - raise ValueError( - "unexpected third argument when constructing a call" - + "to a function" - ) - - super().__init__( - calleeOrResults.type.results, - FlatSymbolRefAttr.get( - calleeOrResults.name.value, context=_get_default_loc_context(loc) - ), - argumentsOrCallee, - loc=loc, - ip=ip, - ) - return - - if isinstance(argumentsOrCallee, list): - raise ValueError( - "when constructing a call to a function by name, " - + "expected the second argument to be a string or a " - + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}" - ) - - if isinstance(argumentsOrCallee, FlatSymbolRefAttr): - super().__init__( - calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip - ) - elif isinstance(argumentsOrCallee, str): - super().__init__( - calleeOrResults, - FlatSymbolRefAttr.get( - argumentsOrCallee, context=_get_default_loc_context(loc) - ), - arguments, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py deleted file mode 100644 index ba72bac3a1526..0000000000000 --- a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py +++ /dev/null @@ -1,124 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union, overload - - -class MapForallToBlocks: - """Specialization for MapForallToBlocks class.""" - - @overload - def __init__( - self, - result_type: Type, - target: Union[Operation, OpView, Value], - *, - grid_dims: Optional[Union[Sequence[int], Attribute]] = None, - generate_gpu_launch: Optional[Union[bool, Attribute]] = None, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - grid_dims: Optional[Union[Sequence[int], Attribute]] = None, - generate_gpu_launch: Optional[Union[bool, Attribute]] = None, - loc=None, - ip=None - ): - ... - - def __init__( - self, - result_type_or_target: Union[Operation, OpView, Type, Value], - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - grid_dims: Optional[Union[Sequence[int], Attribute]] = None, - generate_gpu_launch: Optional[Union[bool, Attribute]] = None, - loc=None, - ip=None - ): - if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_none - else: - result_type = transform.AnyOpType.get() - target = result_type_or_target - - super().__init__( - result_type, - target, - grid_dims=grid_dims, - generate_gpu_launch=generate_gpu_launch, - loc=loc, - ip=ip, - ) - - -class MapNestedForallToThreads: - """Specialization for MapNestedForallToThreads class.""" - - @overload - def __init__( - self, - result_type: Type, - target: Union[Operation, OpView, Value], - *, - block_dims: Optional[Sequence[int]] = None, - warp_size: Optional[Sequence[int]] = None, - sync_after_distribute: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - block_dims: Optional[Sequence[int]] = None, - warp_size: Optional[Sequence[int]] = None, - sync_after_distribute: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - def __init__( - self, - result_type_or_target: Union[Operation, OpView, Value, Type], - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - block_dims: Optional[Union[Sequence[int], Attribute]] = None, - warp_size: Optional[Union[Sequence[int], Attribute]] = None, - sync_after_distribute: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_none - else: - result_type = result_type_or_target.type - target = result_type_or_target - super().__init__( - result_type, - target, - block_dims=block_dims, - warp_size=warp_size, - sync_after_distribute=sync_after_distribute, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py deleted file mode 100644 index 3f6d854ca3e2b..0000000000000 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ /dev/null @@ -1,47 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from typing import Optional, Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context - from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from ._ods_common import get_op_result_or_value as _get_op_result_or_value - - -def isa(cls: Type, ty: Type): - try: - cls(ty) - return True - except ValueError: - return False - - -class StructuredOpMixin: - """All structured ops use the same mixin class.""" - - def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): - super().__init__( - self.build_generic( - results=list(results), - operands=[list(inputs), list(outputs)], - loc=loc, - ip=ip, - ) - ) - - -def select_opview_mixin(parent_opview_cls): - # TODO: This shouldn't be a heuristic: we should have a way to annotate - # the OpView to note that it is a structured op. - if ( - "__init__" not in parent_opview_cls.__dict__ - and hasattr(parent_opview_cls, "inputs") - and hasattr(parent_opview_cls, "outputs") - and hasattr(parent_opview_cls, "result_tensors") - ): - return StructuredOpMixin diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py deleted file mode 100644 index 1cdb2b9e77b5a..0000000000000 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ /dev/null @@ -1,134 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Union - - -class GetParentForOp: - """Extension for GetParentForOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: Optional[int] = None, - ip=None, - loc=None, - ): - if num_loops is None: - num_loops = 1 - super().__init__( - result_type, - _get_op_result_or_value(target), - num_loops=num_loops, - ip=ip, - loc=loc, - ) - - -class LoopOutlineOp: - """Extension for LoopOutlineOp.""" - - def __init__( - self, - function_type: Type, - call_type: Type, - target: Union[Operation, Value], - *, - func_name: Union[str, StringAttr], - ip=None, - loc=None, - ): - super().__init__( - function_type, - call_type, - _get_op_result_or_value(target), - func_name=( - func_name - if isinstance(func_name, StringAttr) - else StringAttr.get(func_name) - ), - ip=ip, - loc=loc, - ) - - -class LoopPeelOp: - """Extension for LoopPeelOp.""" - - def __init__( - self, - main_loop_type: Type, - remainder_loop_type: Type, - target: Union[Operation, Value], - *, - fail_if_already_divisible: Union[bool, BoolAttr] = False, - ip=None, - loc=None, - ): - super().__init__( - main_loop_type, - remainder_loop_type, - _get_op_result_or_value(target), - fail_if_already_divisible=( - fail_if_already_divisible - if isinstance(fail_if_already_divisible, BoolAttr) - else BoolAttr.get(fail_if_already_divisible) - ), - ip=ip, - loc=loc, - ) - - -class LoopPipelineOp: - """Extension for LoopPipelineOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - iteration_interval: Optional[Union[int, IntegerAttr]] = None, - read_latency: Optional[Union[int, IntegerAttr]] = None, - ip=None, - loc=None, - ): - if iteration_interval is None: - iteration_interval = 1 - if read_latency is None: - read_latency = 10 - super().__init__( - result_type, - _get_op_result_or_value(target), - iteration_interval=iteration_interval, - read_latency=read_latency, - ip=ip, - loc=loc, - ) - - -class LoopUnrollOp: - """Extension for LoopUnrollOp.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - factor: Union[int, IntegerAttr], - ip=None, - loc=None, - ): - super().__init__( - _get_op_result_or_value(target), - factor=factor, - ip=ip, - loc=loc, - ) diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py deleted file mode 100644 index 825f1a0a7a6fa..0000000000000 --- a/mlir/python/mlir/dialects/_memref_ops_ext.py +++ /dev/null @@ -1,36 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ._ods_common import get_op_results_or_values as _get_op_results_or_values -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - - -class LoadOp: - """Specialization for the MemRef load operation.""" - - def __init__( - self, - memref: Union[Operation, OpView, Value], - indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None, - *, - loc=None, - ip=None - ): - """Creates a memref load operation. - - Args: - memref: the buffer to load from. - indices: the list of subscripts, may be empty for zero-dimensional - buffers. - loc: user-visible location of the operation. - ip: insertion point. - """ - indices_resolved = [] if indices is None else _get_op_results_or_values(indices) - super().__init__(memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py deleted file mode 100644 index 1cc00bdcbf381..0000000000000 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ /dev/null @@ -1,114 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, overload, Union - - -class MemRefAllocaToGlobalOp: - """Specialization for MemRefAllocaToGlobalOp class.""" - - @overload - def __init__( - self, - get_global_type: Type, - global_type: Type, - alloca: Union[Operation, OpView, Value], - *, - loc=None, - ip=None - ): - ... - - @overload - def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None): - ... - - def __init__( - self, - get_global_type_or_alloca: Union[Operation, OpView, Type, Value], - global_type_or_none: Optional[Type] = None, - alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - loc=None, - ip=None - ): - if isinstance(get_global_type_or_alloca, Type): - get_global_type = get_global_type_or_alloca - global_type = global_type_or_none - alloca = alloca_or_none - else: - get_global_type = transform.AnyOpType.get() - global_type = transform.AnyOpType.get() - alloca = get_global_type_or_alloca - - super().__init__( - get_global_type, - global_type, - alloca, - loc=loc, - ip=ip, - ) - - -class MemRefMultiBufferOp: - """Specialization for MemRefMultiBufferOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - factor: Union[int, IntegerAttr], - *, - skip_analysis: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - factor: Union[int, IntegerAttr], - *, - skip_analysis: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - def __init__( - self, - transformed_type_or_target: Type, - target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None, - factor_or_none: Optional[Union[int, IntegerAttr]] = None, - *, - skip_analysis: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_factor - factor = factor_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target - factor = target_or_factor - - super().__init__( - transformed_type, - target, - factor, - skip_analysis=skip_analysis, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py deleted file mode 100644 index c84d23c16ef93..0000000000000 --- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py +++ /dev/null @@ -1,113 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from typing import Union - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from ._ml_program_ops_gen import * - - -ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" -RESULT_ATTRIBUTE_NAME = "res_attrs" - - -class FuncOp: - """Specialization for the func op class.""" - - def __init__( - self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None - ): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. - if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = ( - StringAttr.get(str(visibility)) if visibility is not None else None - ) - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError("External function does not have a body") - return self.regions[0].blocks[0] - - def add_entry_block(self): - """ - Add an entry block to the function body using the function signature to - infer block arguments. - Returns the newly created block - """ - if not self.is_external: - raise IndexError("The function already has an entry block!") - self.body.blocks.append(*self.type.inputs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context - ) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 895c3228139b3..9cca7d659ec8c 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -9,7 +9,6 @@ __all__ = [ "equally_sized_accessor", - "extend_opview_class", "get_default_loc_context", "get_op_result_or_value", "get_op_results_or_values", @@ -18,64 +17,6 @@ ] -def extend_opview_class(ext_module): - """Decorator to extend an OpView class from an extension module. - - Extension modules can expose various entry-points: - Stand-alone class with the same name as a parent OpView class (i.e. - "ReturnOp"). A name-based match is attempted first before falling back - to a below mechanism. - - def select_opview_mixin(parent_opview_cls): - If defined, allows an appropriate mixin class to be selected dynamically - based on the parent OpView class. Should return NotImplemented if a - decision is not made. - - Args: - ext_module: A module from which to locate extensions. Can be None if not - available. - - Returns: - A decorator that takes an OpView subclass and further extends it as - needed. - """ - - def class_decorator(parent_opview_cls: type): - if ext_module is None: - return parent_opview_cls - mixin_cls = NotImplemented - # First try to resolve by name. - try: - mixin_cls = getattr(ext_module, parent_opview_cls.__name__) - except AttributeError: - # Fall back to a select_opview_mixin hook. - try: - select_mixin = getattr(ext_module, "select_opview_mixin") - except AttributeError: - pass - else: - mixin_cls = select_mixin(parent_opview_cls) - - if mixin_cls is NotImplemented or mixin_cls is None: - return parent_opview_cls - - # Have a mixin_cls. Create an appropriate subclass. - try: - - class LocalOpView(mixin_cls, parent_opview_cls): - pass - - except TypeError as e: - raise TypeError( - f"Could not mixin {mixin_cls} into {parent_opview_cls}" - ) from e - LocalOpView.__name__ = parent_opview_cls.__name__ - LocalOpView.__qualname__ = parent_opview_cls.__qualname__ - return LocalOpView - - return class_decorator - - def segmented_accessor(elements, raw_segments, idx): """ Returns a slice of elements corresponding to the idx-th segment. diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py deleted file mode 100644 index fc9de0b7f7db6..0000000000000 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ /dev/null @@ -1,271 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import pdl -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Union, Optional, Sequence, Mapping -from ._ods_common import ( - get_op_result_or_value as _get_value, - get_op_results_or_values as _get_values, -) - - -class ApplyNativeConstraintOp: - """Specialization for PDL apply native constraint op class.""" - - def __init__( - self, - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(name, args, loc=loc, ip=ip) - - -class ApplyNativeRewriteOp: - """Specialization for PDL apply native rewrite op class.""" - - def __init__( - self, - results: Sequence[Type], - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(results, name, args, loc=loc, ip=ip) - - -class AttributeOp: - """Specialization for PDL attribute op class.""" - - def __init__( - self, - valueType: Optional[Union[OpView, Operation, Value]] = None, - value: Optional[Attribute] = None, - *, - loc=None, - ip=None, - ): - valueType = valueType if valueType is None else _get_value(valueType) - result = pdl.AttributeType.get() - super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) - - -class EraseOp: - """Specialization for PDL erase op class.""" - - def __init__( - self, - operation: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - operation = _get_value(operation) - super().__init__(operation, loc=loc, ip=ip) - - -class OperandOp: - """Specialization for PDL operand op class.""" - - def __init__( - self, - type: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - type = type if type is None else _get_value(type) - result = pdl.ValueType.get() - super().__init__(result, valueType=type, loc=loc, ip=ip) - - -class OperandsOp: - """Specialization for PDL operands op class.""" - - def __init__( - self, - types: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - types = types if types is None else _get_value(types) - result = pdl.RangeType.get(pdl.ValueType.get()) - super().__init__(result, valueType=types, loc=loc, ip=ip) - - -class OperationOp: - """Specialization for PDL operand op class.""" - - def __init__( - self, - name: Optional[Union[str, StringAttr]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None, - types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if types is None: - types = [] - if attributes is None: - attributes = {} - if args is None: - args = [] - args = _get_values(args) - attrNames = [] - attrValues = [] - for attrName, attrValue in attributes.items(): - attrNames.append(StringAttr.get(attrName)) - attrValues.append(_get_value(attrValue)) - attrNames = ArrayAttr.get(attrNames) - types = _get_values(types) - result = pdl.OperationType.get() - super().__init__( - result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip - ) - - -class PatternOp: - """Specialization for PDL pattern op class.""" - - def __init__( - self, - benefit: Union[IntegerAttr, int], - name: Optional[Union[StringAttr, str]] = None, - *, - loc=None, - ip=None, - ): - """Creates an PDL `pattern` operation.""" - super().__init__(benefit, sym_name=name, loc=loc, ip=ip) - self.regions[0].blocks.append() - - @property - def body(self): - """Return the body (block) of the pattern.""" - return self.regions[0].blocks[0] - - -class ReplaceOp: - """Specialization for PDL replace op class.""" - - def __init__( - self, - op: Union[OpView, Operation, Value], - *, - with_op: Optional[Union[OpView, Operation, Value]] = None, - with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - loc=None, - ip=None, - ): - if with_values is None: - with_values = [] - op = _get_value(op) - with_op = with_op if with_op is None else _get_value(with_op) - with_values = _get_values(with_values) - super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) - - -class ResultOp: - """Specialization for PDL result op class.""" - - def __init__( - self, - parent: Union[OpView, Operation, Value], - index: Union[IntegerAttr, int], - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - result = pdl.ValueType.get() - super().__init__(result, parent, index, loc=loc, ip=ip) - - -class ResultsOp: - """Specialization for PDL results op class.""" - - def __init__( - self, - result: Type, - parent: Union[OpView, Operation, Value], - index: Optional[Union[IntegerAttr, int]] = None, - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - super().__init__(result, parent, index=index, loc=loc, ip=ip) - - -class RewriteOp: - """Specialization for PDL rewrite op class.""" - - def __init__( - self, - root: Optional[Union[OpView, Operation, Value]] = None, - name: Optional[Union[StringAttr, str]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - root = root if root is None else _get_value(root) - args = _get_values(args) - super().__init__(args, root=root, name=name, loc=loc, ip=ip) - - def add_body(self): - """Add body (block) to the rewrite.""" - self.regions[0].blocks.append() - return self.body - - @property - def body(self): - """Return the body (block) of the rewrite.""" - return self.regions[0].blocks[0] - - -class TypeOp: - """Specialization for PDL type op class.""" - - def __init__( - self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None - ): - result = pdl.TypeType.get() - super().__init__(result, constantType=constantType, loc=loc, ip=ip) - - -class TypesOp: - """Specialization for PDL types op class.""" - - def __init__( - self, - constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, - *, - loc=None, - ip=None, - ): - if constantTypes is None: - constantTypes = [] - result = pdl.RangeType.get(pdl.TypeType.get()) - super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py deleted file mode 100644 index 89cc8a19895c7..0000000000000 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ /dev/null @@ -1,107 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - -from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, -) - - -class ForOp: - """Specialization for the SCF for op class.""" - - def __init__( - self, - lower_bound, - upper_bound, - step, - iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, - *, - loc=None, - ip=None, - ): - """Creates an SCF `for` operation. - - - `lower_bound` is the value to use as lower bound of the loop. - - `upper_bound` is the value to use as upper bound of the loop. - - `step` is the value to use as loop step. - - `iter_args` is a list of additional loop-carried arguments or an operation - producing them as results. - """ - if iter_args is None: - iter_args = [] - iter_args = _get_op_results_or_values(iter_args) - - results = [arg.type for arg in iter_args] - super().__init__( - self.build_generic( - regions=1, - results=results, - operands=[ - _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step] - ] - + list(iter_args), - loc=loc, - ip=ip, - ) - ) - self.regions[0].blocks.append(self.operands[0].type, *results) - - @property - def body(self): - """Returns the body (block) of the loop.""" - return self.regions[0].blocks[0] - - @property - def induction_variable(self): - """Returns the induction variable of the loop.""" - return self.body.arguments[0] - - @property - def inner_iter_args(self): - """Returns the loop-carried arguments usable within the loop. - - To obtain the loop-carried operands, use `iter_args`. - """ - return self.body.arguments[1:] - - -class IfOp: - """Specialization for the SCF if op class.""" - - def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None): - """Creates an SCF `if` operation. - - - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. - - `hasElse` determines whether the if operation has the else branch. - """ - operands = [] - operands.append(cond) - results = [] - results.extend(results_) - super().__init__( - self.build_generic( - regions=2, results=results, operands=operands, loc=loc, ip=ip - ) - ) - self.regions[0].blocks.append(*[]) - if hasElse: - self.regions[1].blocks.append(*[]) - - @property - def then_block(self): - """Returns the then block of the if operation.""" - return self.regions[0].blocks[0] - - @property - def else_block(self): - """Returns the else block of the if operation.""" - return self.regions[1].blocks[0] diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py deleted file mode 100644 index 3757a3d3b4cce..0000000000000 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ /dev/null @@ -1,759 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import List, Optional, Sequence, Tuple, Union, overload - -StaticIntLike = Union[int, IntegerAttr] -ValueLike = Union[Operation, OpView, Value] -MixedInt = Union[StaticIntLike, ValueLike] - -IntOrAttrList = Sequence[Union[IntegerAttr, int]] -OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] - -BoolOrAttrList = Sequence[Union[BoolAttr, bool]] -OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] - -MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] - -DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]] - - -def _dispatch_dynamic_index_list( - indices: Union[DynamicIndexList, ArrayAttr], -) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]: - """Dispatches a list of indices to the appropriate form. - - This is similar to the custom `DynamicIndexList` directive upstream: - provided indices may be in the form of dynamic SSA values or static values, - and they may be scalable (i.e., as a singleton list) or not. This function - dispatches each index into its respective form. It also extracts the SSA - values and static indices from various similar structures, respectively. - """ - dynamic_indices = [] - static_indices = [ShapedType.get_dynamic_size()] * len(indices) - scalable_indices = [False] * len(indices) - - # ArrayAttr: Extract index values. - if isinstance(indices, ArrayAttr): - indices = [idx for idx in indices] - - def process_nonscalable_index(i, index): - """Processes any form of non-scalable index. - - Returns False if the given index was scalable and thus remains - unprocessed; True otherwise. - """ - if isinstance(index, int): - static_indices[i] = index - elif isinstance(index, IntegerAttr): - static_indices[i] = index.value # pytype: disable=attribute-error - elif isinstance(index, (Operation, Value, OpView)): - dynamic_indices.append(index) - else: - return False - return True - - # Process each index at a time. - for i, index in enumerate(indices): - if not process_nonscalable_index(i, index): - # If it wasn't processed, it must be a scalable index, which is - # provided as a Sequence of one value, so extract and process that. - scalable_indices[i] = True - assert len(index) == 1 - ret = process_nonscalable_index(i, index[0]) - assert ret - - return dynamic_indices, static_indices, scalable_indices - - -# Dispatches `MixedValues` that all represents integers in various forms into -# the following three categories: -# - `dynamic_values`: a list of `Value`s, potentially from op results; -# - `packed_values`: a value handle, potentially from an op result, associated -# to one or more payload operations of integer type; -# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python -# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. -# The input is in the form for `packed_values`, only that result is set and the -# other two are empty. Otherwise, the input can be a mix of the other two forms, -# and for each dynamic value, a special value is added to the `static_values`. -def _dispatch_mixed_values( - values: MixedValues, -) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]: - dynamic_values = [] - packed_values = None - static_values = None - if isinstance(values, ArrayAttr): - static_values = values - elif isinstance(values, (Operation, Value, OpView)): - packed_values = values - else: - static_values = [] - for size in values or []: - if isinstance(size, int): - static_values.append(size) - else: - static_values.append(ShapedType.get_dynamic_size()) - dynamic_values.append(size) - static_values = DenseI64ArrayAttr.get(static_values) - - return (dynamic_values, packed_values, static_values) - - -def _get_value_or_attribute_value( - value_or_attr: Union[any, Attribute, ArrayAttr] -) -> any: - if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): - return value_or_attr.value - if isinstance(value_or_attr, ArrayAttr): - return _get_value_list(value_or_attr) - return value_or_attr - - -def _get_value_list( - sequence_or_array_attr: Union[Sequence[any], ArrayAttr] -) -> Sequence[any]: - return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] - - -def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: - if values is None: - return None - - # Turn into a Python list of Python ints. - values = _get_value_list(values) - - # Make an ArrayAttr of IntegerAttrs out of it. - return ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] - ) - - -def _get_int_array_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] -) -> ArrayAttr: - """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. - - The input has to be a collection of collection of integers, where any - Python Sequence and ArrayAttr are admissible collections and Python ints and - any IntegerAttr are admissible integers. Both levels of collections are - turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. - If the input is None, an empty ArrayAttr is returned. - """ - if values is None: - return None - - # Make sure the outer level is a list. - values = _get_value_list(values) - - # The inner level is now either invalid or a mixed sequence of ArrayAttrs and - # Sequences. Make sure the nested values are all lists. - values = [_get_value_list(nested) for nested in values] - - # Turn each nested list into an ArrayAttr. - values = [_get_int_array_attr(nested) for nested in values] - - # Turn the outer list into an ArrayAttr. - return ArrayAttr.get(values) - - -class BufferizeToAllocationOp: - """Specialization for BufferizeToAllocationOp class.""" - - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - memory_space: Optional[Union[int, str, Attribute]] = None, - memcpy_op: Optional[str] = None, - alloc_op: Optional[str] = None, - bufferize_destination_only: Optional[bool] = None, - loc=None, - ip=None, - ): - # No other types are allowed, so hard-code those here. - allocated_buffer_type = transform.AnyValueType.get() - new_ops_type = transform.AnyOpType.get() - - if isinstance(memory_space, int): - memory_space = str(memory_space) - if isinstance(memory_space, str): - memory_space = Attribute.parse(memory_space) - - super().__init__( - allocated_buffer_type, - new_ops_type, - target, - memory_space=memory_space, - memcpy_op=memcpy_op, - alloc_op=alloc_op, - bufferize_destination_only=bufferize_destination_only, - loc=loc, - ip=ip, - ) - - -class DecomposeOp: - """Specialization for DecomposeOp class.""" - - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - transformed_type = transform.AnyOpType.get() - super().__init__(transformed_type, target, loc=loc, ip=ip) - - -class FuseIntoContainingOp: - """Specialization for FuseIntoContainingOp class.""" - - @overload - def __init__( - self, - fused_op_type: Type, - new_containing_op_type: Type, - producer_op: Union[Operation, OpView, Value], - containing_op: Union[Operation, OpView, Value], - *, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - producer_op: Union[Operation, OpView, Value], - containing_op: Union[Operation, OpView, Value], - *, - loc=None, - ip=None, - ): - ... - - def __init__( - self, - fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value], - new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value], - producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None, - containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - loc=None, - ip=None, - ): - if isinstance(fused_op_type_or_producer_op, Type): - if not isinstance(new_containing_op_type_or_containing_op, Type): - raise TypeError( - "If 'fused_op_type_or_producer_op' is a type, then " - "'new_containing_op_type_or_containing_op' is expected " - "to be one as well." - ) - fused_op_type = fused_op_type_or_producer_op - new_containing_op_type = new_containing_op_type_or_containing_op - producer_op = producer_op_or_none - containing_op = containing_op_or_none - else: - fused_op_type = transform.AnyOpType.get() - new_containing_op_type = transform.AnyOpType.get() - producer_op = fused_op_type_or_producer_op - containing_op = new_containing_op_type_or_containing_op - - super().__init__( - fused_op_type, - new_containing_op_type, - producer_op, - containing_op, - loc=loc, - ip=ip, - ) - - -class GeneralizeOp: - """Specialization for GeneralizeOp class.""" - - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - transformed_type = transform.AnyOpType.get() - super().__init__(transformed_type, target, loc=loc, ip=ip) - - -class InterchangeOp: - """Specialization for InterchangeOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - iterator_interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - transformed_type = transform.AnyOpType.get() - super().__init__( - transformed_type, - target, - iterator_interchange=iterator_interchange, - loc=loc, - ip=ip, - ) - - -class MapCopyToThreadsOp: - """Specialization for MapCopyToThreadsOp class.""" - - @overload - def __init__( - self, - forall_op_type: Type, - tiled_op_type: Type, - target: Union[Operation, OpView, Value], - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - ... - - def __init__( - self, - forall_op_type_or_target: Union[Operation, OpView, Type, Value], - tiled_op_type_or_none: Optional[Type] = None, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - if isinstance(forall_op_type_or_target, Type): - forall_op_type = forall_op_type_or_target - tiled_op_type = tiled_op_type_or_none - target = target_or_none - else: - forall_op_type = transform.AnyOpType.get() - tiled_op_type = transform.AnyOpType.get() - target = forall_op_type_or_target - - super().__init__( - forall_op_type, - tiled_op_type, - target, - total_num_threads=total_num_threads, - desired_bit_alignment=desired_bit_alignment, - loc=loc, - ip=ip, - ) - - -class VectorizeOp: - """Specialization for VectorizeOp class.""" - - def __init__( - self, - target: Union[Operation, OpView, Value], - vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - *, - vectorize_nd_extract: Optional[bool] = None, - scalable_sizes: OptionalBoolList = None, - static_vector_sizes: OptionalIntList = None, - loc=None, - ip=None, - ): - if ( - scalable_sizes is None - and static_vector_sizes is None - and vector_sizes is None - ): - dynamic_vector_sizes = [] - elif scalable_sizes is None and static_vector_sizes is None: - ( - dynamic_vector_sizes, - static_vector_sizes, - scalable_sizes, - ) = _dispatch_dynamic_index_list(vector_sizes) - elif scalable_sizes is None or static_vector_sizes is None: - raise TypeError( - "'scalable_sizes' and 'static_vector_sizes' must either both " - "be given explicitly or both be given as part of 'vector_sizes'." - ) - else: - dynamic_vector_sizes = vector_sizes - - super().__init__( - target, - vector_sizes=dynamic_vector_sizes, - static_vector_sizes=static_vector_sizes, - scalable_sizes=scalable_sizes, - vectorize_nd_extract=vectorize_nd_extract, - loc=loc, - ip=ip, - ) - - -class MatchOp: - """Specialization for MatchOp class.""" - - @overload - @classmethod - def match_op_names( - cls, - target: Union[Operation, Value], - names: Union[str, Sequence[str]], - *, - loc=None, - ip=None, - ): - ... - - @overload - @classmethod - def match_op_names( - cls, - result_type: Type, - target: Union[Operation, Value], - names: Union[str, Sequence[str]], - *, - loc=None, - ip=None, - ): - ... - - @classmethod - def match_op_names( - cls, - result_type_or_target: Union[Type, Operation, Value], - target_or_names: Union[Operation, Value, Sequence[str], str], - names_or_none: Optional[Union[Sequence[str], str]] = None, - *, - loc=None, - ip=None, - ): - if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_names - names = names_or_none - else: - result_type = transform.AnyOpType.get() - target = result_type_or_target - names = target_or_names - - if isinstance(names, str): - names = [names] - - return cls( - result_type, - target, - ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), - loc=loc, - ip=ip, - ) - - -class MultiTileSizesOp: - """Specialization for MultiTileSizesOp class.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - dimension: Union[int, IntegerAttr], - target_size: Union[int, IntegerAttr], - divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, - loc=None, - ip=None, - ): - super().__init__( - result_type, - result_type, - result_type, - target, - dimension=dimension, - target_size=target_size, - divisor=divisor, - loc=loc, - ip=ip, - ) - - -class PadOp: - """Specialization for PadOp class.""" - - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, - padding_dimensions: OptionalIntList = None, - pad_to_multiple_of: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[ - Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] - ] = None, - copy_back_op: Optional[Union[str, StringAttr]] = None, - loc=None, - ip=None, - ): - transpose_paddings = _get_int_array_array_attr(transpose_paddings) - - any_op_type = transform.AnyOpType.get() - super().__init__( - any_op_type, - any_op_type, - any_op_type, - target, - padding_values=padding_values, - padding_dimensions=padding_dimensions, - pad_to_multiple_of=pad_to_multiple_of, - pack_paddings=pack_paddings, - transpose_paddings=transpose_paddings, - copy_back_op=copy_back_op, - loc=loc, - ip=ip, - ) - - -class ScalarizeOp: - """Specialization for ScalarizeOp class.""" - - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - result_type = transform.AnyOpType.get() - super().__init__(result_type, target, loc=loc, ip=ip) - - -class SplitOp: - """Specialization for SplitOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - dimension: Union[int, Attribute], - split_point: Union[int, Operation, Value, Attribute], - *, - loc=None, - ip=None, - ): - if isinstance(split_point, int): - static_split_point = split_point - dynamic_split_point = None - else: - static_split_point = ShapedType.get_dynamic_size() - dynamic_split_point = split_point - - super().__init__( - target.type, - target.type, - target, - dimension=dimension, - static_split_point=static_split_point, - dynamic_split_point=dynamic_split_point, - loc=loc, - ip=ip, - ) - - -class TileUsingForOp: - """Specialization for TileUsingForOp class.""" - - @overload - def __init__( - self, - loop_types: Union[Type, List[Type]], - target: Union[Operation, Value], - *, - sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - def __init__( - self, - loop_types_or_target: Union[Type, List[Type], Operation, Value], - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ( - dynamic_sizes, - static_sizes, - scalable_sizes, - ) = _dispatch_dynamic_index_list(sizes) - - num_loops = sum(v if v == 0 else 1 for v in static_sizes) - - if isinstance(loop_types_or_target, (Operation, Value, OpView)): - loop_types = [transform.AnyOpType.get()] * num_loops - target = loop_types_or_target - assert ( - target_or_none is None - ), "Cannot construct TileUsingForOp with two targets." - else: - loop_types = ( - ([loop_types_or_target] * num_loops) - if isinstance(loop_types_or_target, Type) - else loop_types_or_target - ) - target = target_or_none - - super().__init__( - target.type, - loop_types, - target, - dynamic_sizes=dynamic_sizes, - static_sizes=static_sizes, - interchange=interchange, - scalable_sizes=scalable_sizes, - loc=loc, - ip=ip, - ) - - -class TileUsingForallOp: - """Specialization for TileUsingForallOp class.""" - - @overload - def __init__( - self, - loops_type: Type, - tiled_op_type: Type, - target: Union[Operation, Value, OpView], - *, - num_threads: Optional[MixedValues] = None, - tile_sizes: MixedValues = None, - mapping=None, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - num_threads: Optional[MixedValues] = None, - tile_sizes: MixedValues = None, - mapping=None, - loc=None, - ip=None, - ): - ... - - def __init__( - self, - loops_type_or_target: Union[ - Type, Union[Operation, Value, OpView] # loops_type - ], # target - tiled_op_type_or_none: Optional[Type] = None, - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - num_threads: MixedValues = None, - tile_sizes: MixedValues = None, - mapping=None, - loc=None, - ip=None, - ): - # `Type` arguments in the front are optional: add default values to front. - if isinstance(loops_type_or_target, Type): - # First overload: type arguments provided. - if not isinstance(tiled_op_type_or_none, Type): - raise TypeError( - "If 'loops_type_or_target' is a type, then " - "'tiled_op_type_or_none' is expected to be one as well." - ) - loops_type = loops_type_or_target - tiled_op_type = tiled_op_type_or_none - target = target_or_none - else: - # Last overload: type arguments missing. - loops_type = transform.AnyOpType.get() - tiled_op_type = transform.AnyOpType.get() - target = loops_type_or_target - - # Unpack mixed num_threads. - ( - dynamic_num_threads, - packed_num_threads, - num_threads_attr, - ) = _dispatch_mixed_values(num_threads) - - # Unpack mixed tile_sizes. - ( - dynamic_tile_sizes, - packed_tile_sizes, - tile_sizes_attr, - ) = _dispatch_mixed_values(tile_sizes) - - super().__init__( - loops_type, - tiled_op_type, - target=target, - tile_sizes=dynamic_tile_sizes, - packed_tile_sizes=packed_tile_sizes, - static_tile_sizes=tile_sizes_attr, - num_threads=dynamic_num_threads, - packed_num_threads=packed_num_threads, - static_num_threads=num_threads_attr, - mapping=mapping, - loc=loc, - ip=ip, - ) - - -class VectorizeChildrenAndApplyPatternsOp: - """Specialization for VectorizeChildrenAndApplyPatternsOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - disable_multi_reduction_to_contract_patterns: bool = False, - disable_transfer_permutation_map_lowering_patterns: bool = False, - vectorize_nd_extract: bool = False, - vectorize_padding: bool = False, - loc=None, - ip=None, - ): - transformed_type = transform.AnyOpType.get() - super().__init__( - transformed_type, - target, - disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, - disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, - vectorize_nd_extract=vectorize_nd_extract, - vectorize_padding=vectorize_padding, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py deleted file mode 100644 index 09b9ec68db7d9..0000000000000 --- a/mlir/python/mlir/dialects/_tensor_ops_ext.py +++ /dev/null @@ -1,44 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Any, Optional, Sequence, Union -from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, -) - - -class EmptyOp: - """Extends the tensor.empty op.""" - - def __init__( - self, - sizes: Sequence[Union[int, Value]], - element_type: Type, - *, - loc=None, - ip=None - ): - """Constructs an `empty` with mixed static/dynamic sizes.""" - # TODO: Refactor the EmptyOp to take an element type attribute and - # then use normal result type inference, unifying the Python and C++ side - # with a standard mechanism (versus stashing that in builders). - dynamic_sizes = [] - static_sizes = [] - for s in sizes: - if isinstance(s, int): - static_sizes.append(s) - else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(s) - result_type = RankedTensorType.get(static_sizes, element_type) - op = self.build_generic( - results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip - ) - OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py deleted file mode 100644 index 996093fbc913e..0000000000000 --- a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py +++ /dev/null @@ -1,64 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, overload, Union - - -class MakeLoopIndependentOp: - """Specialization for MakeLoopIndependentOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - num_loops: Union[int, IntegerAttr], - *, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - num_loops: Union[int, IntegerAttr], - *, - loc=None, - ip=None - ): - ... - - def __init__( - self, - transformed_type_or_target: Type, - target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None, - num_loops_or_none: Optional[Union[int, IntegerAttr]] = None, - *, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_num_loops - num_loops = num_loops_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target - num_loops = target_or_num_loops - - super().__init__( - transformed_type, - target, - num_loops, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py deleted file mode 100644 index b1e7b892536f4..0000000000000 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ /dev/null @@ -1,176 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - - -class CastOp: - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None, - ): - super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) - - -class ApplyPatternsOp: - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - loc=None, - ip=None, - ): - operands = [] - operands.append(_get_op_result_or_value(target)) - super().__init__( - self.build_generic( - attributes={}, - results=[], - operands=operands, - successors=None, - regions=None, - loc=loc, - ip=ip, - ) - ) - self.regions[0].blocks.append() - - @property - def patterns(self) -> Block: - return self.regions[0].blocks[0] - - -class testGetParentOp: - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - isolated_from_above: bool = False, - op_name: Optional[str] = None, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - isolated_from_above=isolated_from_above, - op_name=op_name, - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) - - -class MergeHandlesOp: - def __init__( - self, - handles: Sequence[Union[Operation, Value]], - *, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h) for h in handles], - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) - - -class ReplicateOp: - def __init__( - self, - pattern: Union[Operation, Value], - handles: Sequence[Union[Operation, Value]], - *, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h).type for h in handles], - _get_op_result_or_value(pattern), - [_get_op_result_or_value(h) for h in handles], - loc=loc, - ip=ip, - ) - - -class SequenceOp: - def __init__( - self, - failure_propagation_mode, - results: Sequence[Type], - target: Union[Operation, Value, Type], - extra_bindings: Optional[ - Union[Sequence[Value], Sequence[Type], Operation, OpView] - ] = None, - ): - root = ( - _get_op_result_or_value(target) - if isinstance(target, (Operation, Value)) - else None - ) - root_type = root.type if not isinstance(target, Type) else target - - if extra_bindings is None: - extra_bindings = [] - if isinstance(extra_bindings, (Operation, OpView)): - extra_bindings = _get_op_results_or_values(extra_bindings) - - extra_binding_types = [] - if len(extra_bindings) != 0: - if isinstance(extra_bindings[0], Type): - extra_binding_types = extra_bindings - extra_bindings = [] - else: - extra_binding_types = [v.type for v in extra_bindings] - - super().__init__( - results_=results, - failure_propagation_mode=failure_propagation_mode, - root=root, - extra_bindings=extra_bindings, - ) - self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - @property - def bodyExtraArgs(self) -> BlockArgumentList: - return self.body.arguments[1:] - - -class YieldOp: - def __init__( - self, - operands: Optional[Union[Operation, Sequence[Value]]] = None, - *, - loc=None, - ip=None, - ): - if operands is None: - operands = [] - super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py deleted file mode 100644 index c4e4b4b4254b0..0000000000000 --- a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py +++ /dev/null @@ -1,55 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Union - -class PDLMatchOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - pattern_name: Union[Attribute, str], - *, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - pattern_name, - loc=loc, - ip=ip, - ) - - -class WithPDLPatternsOp: - - def __init__(self, - target: Union[Operation, Value, Type], - *, - loc=None, - ip=None): - root = _get_op_result_or_value(target) if not isinstance(target, - Type) else None - root_type = target if isinstance(target, Type) else root.type - super().__init__(root=root, loc=loc, ip=ip) - self.regions[0].blocks.append(root_type) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py index 8a2a64c7c40d1..1eaccfa73a85c 100644 --- a/mlir/python/mlir/dialects/affine.py +++ b/mlir/python/mlir/dialects/affine.py @@ -1,5 +1,50 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._affine_ops_gen import * +from ._affine_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AffineStoreOp(AffineStoreOp): + """Specialization for the Affine store operation.""" + + def __init__( + self, + value: Union[Operation, OpView, Value], + memref: Union[Operation, OpView, Value], + map: AffineMap = None, + *, + map_operands=None, + loc=None, + ip=None, + ): + """Creates an affine store operation. + + - `value`: the value to store into the memref. + - `memref`: the buffer to store into. + - `map`: the affine map that maps the map_operands to the index of the + memref. + - `map_operands`: the list of arguments to substitute the dimensions, + then symbols in the affine map, in increasing order. + """ + map = map if map is not None else [] + map_operands = map_operands if map_operands is not None else [] + indicies = [_get_op_result_or_value(op) for op in map_operands] + _ods_successors = None + super().__init__( + value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip + ) diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index fb13beb63ca66..83aca0d58bf2c 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -3,4 +3,75 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._arith_ops_gen import * +from ._arith_ops_gen import _Dialect from ._arith_enum_gen import * + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) + + from typing import Any, List, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +def _isa(obj: Any, cls: type): + try: + cls(obj) + except ValueError: + return False + return True + + +def _is_any_of(obj: Any, classes: List[type]): + return any(_isa(obj, cls) for cls in classes) + + +def _is_integer_like_type(type: Type): + return _is_any_of(type, [IntegerType, IndexType]) + + +def _is_float_type(type: Type): + return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConstantOp(ConstantOp): + """Specialization for the constant op class.""" + + def __init__( + self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + ): + if isinstance(value, int): + super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + else: + super().__init__(value, loc=loc, ip=ip) + + @classmethod + def create_index(cls, value: int, *, loc=None, ip=None): + """Create an index-typed constant.""" + return cls( + IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip + ) + + @property + def type(self): + return self.results[0].type + + @property + def value(self): + return Attribute(self.operation.attributes["value"]) + + @property + def literal_value(self) -> Union[int, float]: + if _is_integer_like_type(self.type): + return IntegerAttr(self.value).value + elif _is_float_type(self.type): + return FloatAttr(self.value).value + else: + raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py index 759b6aa24a9ff..0ce5448ace4b1 100644 --- a/mlir/python/mlir/dialects/bufferization.py +++ b/mlir/python/mlir/dialects/bufferization.py @@ -3,4 +3,40 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._bufferization_ops_gen import * +from ._bufferization_ops_gen import _Dialect from ._bufferization_enum_gen import * + +try: + from typing import Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context, _cext as _ods_cext + + from typing import Any, List, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AllocTensorOp(AllocTensorOp): + """Extends the bufferization.alloc_tensor op.""" + + def __init__( + self, + tensor_type: Type, + dynamic_sizes: Sequence[Value], + copy: Value, + size_hint: Value, + escape: BoolAttr, + *, + loc=None, + ip=None, + ): + """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" + super().__init__( + tensor_type, + dynamic_sizes, + copy=copy, + size_hint=size_hint, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py index 30279e1611f99..b71cc2466d464 100644 --- a/mlir/python/mlir/dialects/builtin.py +++ b/mlir/python/mlir/dialects/builtin.py @@ -3,3 +3,23 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._builtin_ops_gen import * +from ._builtin_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ModuleOp(ModuleOp): + """Specialization for the module op class.""" + + def __init__(self, *, loc=None, ip=None): + super().__init__(loc=loc, ip=ip) + body = self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index dc554c22173bc..9c6c4c9092c7a 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -3,3 +3,326 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._func_ops_gen import * +from ._func_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) + + import inspect + + from typing import Any, List, Optional, Sequence, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConstantOp(ConstantOp): + """Specialization for the constant op class.""" + + def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): + super().__init__(result, value, loc=loc, ip=ip) + + @property + def type(self): + return self.results[0].type + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): + """ + Add an entry block to the function body using the function signature to + infer block arguments. + Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + + @classmethod + def from_py_func( + FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None, + ): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. + + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import func + + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and ( + param.kind == param.POSITIONAL_OR_KEYWORD + or param.kind == param.KEYWORD_ONLY + ): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results + ) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None." + ) + else: + # Coerce return values, add ReturnOp and rewrite func type. + if return_values is None: + return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) + elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. + return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results + else: + return_values = list(return_values) + func.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get( + inputs=inputs, results=return_types + ) + func_op.attributes["function_type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = func.CallOp( + return_types, FlatSymbolRefAttr.get(symbol_name), call_args + ) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator + + +@_ods_cext.register_operation(_Dialect, replace=True) +class CallOp(CallOp): + """Specialization for the call op class.""" + + def __init__( + self, + calleeOrResults: Union[FuncOp, List[Type]], + argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], + arguments: Optional[List] = None, + *, + loc=None, + ip=None, + ): + """Creates an call operation. + + The constructor accepts three different forms: + + 1. A function op to be called followed by a list of arguments. + 2. A list of result types, followed by the name of the function to be + called as string, following by a list of arguments. + 3. A list of result types, followed by the name of the function to be + called as symbol reference attribute, followed by a list of arguments. + + For example + + f = func.FuncOp("foo", ...) + func.CallOp(f, [args]) + func.CallOp([result_types], "foo", [args]) + + In all cases, the location and insertion point may be specified as keyword + arguments if not provided by the surrounding context managers. + """ + + # TODO: consider supporting constructor "overloads", e.g., through a custom + # or pybind-provided metaclass. + if isinstance(calleeOrResults, FuncOp): + if not isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function, expected " + + "the second argument to be a list of call arguments, " + + f"got {type(argumentsOrCallee)}" + ) + if arguments is not None: + raise ValueError( + "unexpected third argument when constructing a call" + + "to a function" + ) + + super().__init__( + calleeOrResults.type.results, + FlatSymbolRefAttr.get( + calleeOrResults.name.value, context=_get_default_loc_context(loc) + ), + argumentsOrCallee, + loc=loc, + ip=ip, + ) + return + + if isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function by name, " + + "expected the second argument to be a string or a " + + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}" + ) + + if isinstance(argumentsOrCallee, FlatSymbolRefAttr): + super().__init__( + calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip + ) + elif isinstance(argumentsOrCallee, str): + super().__init__( + calleeOrResults, + FlatSymbolRefAttr.get( + argumentsOrCallee, context=_get_default_loc_context(loc) + ), + arguments, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 6f9d72164429e..f91fc8b716008 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -310,7 +310,7 @@ def emit_named_structured_op( ) # Set the index attributes used to compute the indexing maps. - named_op = getattr(linalg, op_class_name)(ins, outs, result_types) + named_op = getattr(linalg, op_class_name)(result_types, ins, outs) for name, value in index_attrs.items(): named_op.operation.attributes[name] = value diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index a8f8f8e0fbd68..19734a80a107b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -296,35 +296,39 @@ def quantized_matmul( @linalg_structured_op -def matmul_transpose_a(A=TensorDef(T1, S.K, S.N), - B=TensorDef(T2, S.K, S.M), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Performs a matrix multiplication of two 2D inputs with lhs operand - transposed. +def matmul_transpose_a( + A=TensorDef(T1, S.K, S.N), + B=TensorDef(T2, S.K, S.M), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs with lhs operand + transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) @linalg_structured_op -def matmul_transpose_b(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.N, S.K), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Performs a matrix multiplication of two 2D inputs with rhs operand - transposed. +def matmul_transpose_b( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.N, S.K), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs with rhs operand + transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) @linalg_structured_op @@ -390,36 +394,41 @@ def batch_matmul( @linalg_structured_op -def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs where lhs operand - has its non-batch dimensions transposed. +def batch_matmul_transpose_a( + A=TensorDef(T1, Batch, S.K, S.M), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs where lhs operand + has its non-batch dimensions transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \ - * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] + ) @linalg_structured_op -def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.N, S.K), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs where rhs operand - has its non-batch dimensions transposed. +def batch_matmul_transpose_b( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.N, S.K), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs where rhs operand + has its non-batch dimensions transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, - D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.n, D.k]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.n, D.k] + ) @linalg_structured_op diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index 3afb6a70cb9e0..111ad2178703d 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -3,3 +3,41 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._memref_ops_gen import * +from ._memref_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoadOp(LoadOp): + """Specialization for the MemRef load operation.""" + + def __init__( + self, + memref: Union[Operation, OpView, Value], + indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + """Creates a memref load operation. + + Args: + memref: the buffer to load from. + indices: the list of subscripts, may be empty for zero-dimensional + buffers. + loc: user-visible location of the operation. + ip: insertion point. + """ + indices_resolved = [] if indices is None else _get_op_results_or_values(indices) + super().__init__(memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py index a654529b4bb88..dfb6d7f2c03b1 100644 --- a/mlir/python/mlir/dialects/ml_program.py +++ b/mlir/python/mlir/dialects/ml_program.py @@ -2,4 +2,118 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Union + from ._ml_program_ops_gen import * +from ._ml_program_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. + Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py index dda2b7d652196..a8d9c56f4233d 100644 --- a/mlir/python/mlir/dialects/pdl.py +++ b/mlir/python/mlir/dialects/pdl.py @@ -3,4 +3,289 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._pdl_ops_gen import * +from ._pdl_ops_gen import _Dialect from .._mlir_libs._mlirDialectsPDL import * + + +try: + from ..ir import * + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union, Optional, Sequence, Mapping +from ._ods_common import ( + get_op_result_or_value as _get_value, + get_op_results_or_values as _get_values, + _cext as _ods_cext, +) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyNativeConstraintOp(ApplyNativeConstraintOp): + """Specialization for PDL apply native constraint op class.""" + + def __init__( + self, + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + super().__init__(name, args, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyNativeRewriteOp(ApplyNativeRewriteOp): + """Specialization for PDL apply native rewrite op class.""" + + def __init__( + self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + super().__init__(results, name, args, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AttributeOp(AttributeOp): + """Specialization for PDL attribute op class.""" + + def __init__( + self, + valueType: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None, + ): + valueType = valueType if valueType is None else _get_value(valueType) + result = pdl.AttributeType.get() + super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EraseOp(EraseOp): + """Specialization for PDL erase op class.""" + + def __init__( + self, + operation: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + operation = _get_value(operation) + super().__init__(operation, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperandOp(OperandOp): + """Specialization for PDL operand op class.""" + + def __init__( + self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + type = type if type is None else _get_value(type) + result = pdl.ValueType.get() + super().__init__(result, valueType=type, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperandsOp(OperandsOp): + """Specialization for PDL operands op class.""" + + def __init__( + self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + types = types if types is None else _get_value(types) + result = pdl.RangeType.get(pdl.ValueType.get()) + super().__init__(result, valueType=types, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperationOp(OperationOp): + """Specialization for PDL operand op class.""" + + def __init__( + self, + name: Optional[Union[str, StringAttr]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None, + types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if types is None: + types = [] + if attributes is None: + attributes = {} + if args is None: + args = [] + args = _get_values(args) + attrNames = [] + attrValues = [] + for attrName, attrValue in attributes.items(): + attrNames.append(StringAttr.get(attrName)) + attrValues.append(_get_value(attrValue)) + attrNames = ArrayAttr.get(attrNames) + types = _get_values(types) + result = pdl.OperationType.get() + super().__init__( + result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class PatternOp(PatternOp): + """Specialization for PDL pattern op class.""" + + def __init__( + self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + """Creates an PDL `pattern` operation.""" + super().__init__(benefit, sym_name=name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + """Return the body (block) of the pattern.""" + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ReplaceOp(ReplaceOp): + """Specialization for PDL replace op class.""" + + def __init__( + self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + loc=None, + ip=None, + ): + if with_values is None: + with_values = [] + op = _get_value(op) + with_op = with_op if with_op is None else _get_value(with_op) + with_values = _get_values(with_values) + super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ResultOp(ResultOp): + """Specialization for PDL result op class.""" + + def __init__( + self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + result = pdl.ValueType.get() + super().__init__(result, parent, index, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ResultsOp(ResultsOp): + """Specialization for PDL results op class.""" + + def __init__( + self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + super().__init__(result, parent, index=index, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class RewriteOp(RewriteOp): + """Specialization for PDL rewrite op class.""" + + def __init__( + self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + root = root if root is None else _get_value(root) + args = _get_values(args) + super().__init__(args, root=root, name=name, loc=loc, ip=ip) + + def add_body(self): + """Add body (block) to the rewrite.""" + self.regions[0].blocks.append() + return self.body + + @property + def body(self): + """Return the body (block) of the rewrite.""" + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypeOp(TypeOp): + """Specialization for PDL type op class.""" + + def __init__( + self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None + ): + result = pdl.TypeType.get() + super().__init__(result, constantType=constantType, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypesOp(TypesOp): + """Specialization for PDL types op class.""" + + def __init__( + self, + constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, + *, + loc=None, + ip=None, + ): + if constantTypes is None: + constantTypes = [] + result = pdl.RangeType.get(pdl.TypeType.get()) + super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 8465af048a280..6579e02d8549e 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType +from .._mlir_libs._mlirPythonTest import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, +) def register_python_test_dialect(context, load=True): diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 49685ca2271fc..43ad9f4e2d65f 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -2,11 +2,122 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional, Sequence from ._scf_ops_gen import * +from ._scf_ops_gen import _Dialect from .arith import constant -from ..ir import * + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +_ForOp = ForOp + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ForOp(_ForOp): + """Specialization for the SCF for op class.""" + + def __init__( + self, + lower_bound, + upper_bound, + step, + iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + """Creates an SCF `for` operation. + + - `lower_bound` is the value to use as lower bound of the loop. + - `upper_bound` is the value to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. + """ + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + + results = [arg.type for arg in iter_args] + super(_ForOp, self).__init__( + self.build_generic( + regions=1, + results=results, + operands=[ + _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step] + ] + + list(iter_args), + loc=loc, + ip=ip, + ) + ) + self.regions[0].blocks.append(self.operands[0].type, *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[1:] + + +_IfOp = IfOp + + +@_ods_cext.register_operation(_Dialect, replace=True) +class IfOp(_IfOp): + """Specialization for the SCF if op class.""" + + def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None): + """Creates an SCF `if` operation. + + - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. + - `hasElse` determines whether the if operation has the else branch. + """ + operands = [] + operands.append(cond) + results = [] + results.extend(results_) + super(_IfOp, self).__init__( + self.build_generic( + regions=2, results=results, operands=operands, loc=loc, ip=ip + ) + ) + self.regions[0].blocks.append(*[]) + if hasElse: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self): + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self): + """Returns the else block of the if operation.""" + return self.regions[1].blocks[0] def for_( diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 26edf6b6436da..67248748eaf3a 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -3,3 +3,40 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._tensor_ops_gen import * +from ._tensor_ops_gen import _Dialect + +try: + from ..ir import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Sequence, Union +from ._ods_common import _cext as _ods_cext + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmptyOp(EmptyOp): + """Extends the tensor.empty op.""" + + def __init__( + self, + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None, + ): + """Constructs an `empty` with mixed static/dynamic sizes.""" + # TODO: Refactor the EmptyOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). + dynamic_sizes = [] + static_sizes = [] + for s in sizes: + if isinstance(s, int): + static_sizes.append(s) + else: + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(s) + result_type = RankedTensorType.get(static_sizes, element_type) + super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index b020ad35fcf06..f7a2026e800ae 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -4,4 +4,174 @@ from .._transform_enum_gen import * from .._transform_ops_gen import * +from .._transform_ops_gen import _Dialect from ..._mlir_libs._mlirDialectsTransform import * + +try: + from ...ir import * + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class CastOp(CastOp): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None, + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyPatternsOp(ApplyPatternsOp): + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + loc=None, + ip=None, + ): + super().__init__(target, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def patterns(self) -> Block: + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GetParentOp(GetParentOp): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MergeHandlesOp(MergeHandlesOp): + def __init__( + self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h) for h in handles], + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ReplicateOp(ReplicateOp): + def __init__( + self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h).type for h in handles], + _get_op_result_or_value(pattern), + [_get_op_result_or_value(h) for h in handles], + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SequenceOp(SequenceOp): + def __init__( + self, + failure_propagation_mode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[ + Union[Sequence[Value], Sequence[Type], Operation, OpView] + ] = None, + ): + root = ( + _get_op_result_or_value(target) + if isinstance(target, (Operation, Value)) + else None + ) + root_type = root.type if not isinstance(target, Type) else target + + if extra_bindings is None: + extra_bindings = [] + if isinstance(extra_bindings, (Operation, OpView)): + extra_bindings = _get_op_results_or_values(extra_bindings) + + extra_binding_types = [] + if len(extra_bindings) != 0: + if isinstance(extra_bindings[0], Type): + extra_binding_types = extra_bindings + extra_bindings = [] + else: + extra_binding_types = [v.type for v in extra_bindings] + + super().__init__( + results_=results, + failure_propagation_mode=failure_propagation_mode, + root=root, + extra_bindings=extra_bindings, + ) + self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class YieldOp(YieldOp): + def __init__( + self, + operands: Optional[Union[Operation, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + if operands is None: + operands = [] + super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py index eb77b746cf864..485a8a36b6305 100644 --- a/mlir/python/mlir/dialects/transform/bufferization.py +++ b/mlir/python/mlir/dialects/transform/bufferization.py @@ -3,3 +3,132 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._bufferization_transform_ops_gen import * +from .._bufferization_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from enum import Enum +from typing import Optional, overload, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp): + """Specialization for EmptyTensorToAllocTensorOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.OperationType.get("bufferization.alloc_tensor") + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OneShotBufferizeOp(OneShotBufferizeOp): + """Specialization for OneShotBufferizeOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + allow_return_allocs_from_loops=allow_return_allocs_from_loops, + allow_unknown_ops=allow_unknown_ops, + bufferize_function_boundaries=bufferize_function_boundaries, + function_boundary_type_conversion=function_boundary_type_conversion, + memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py index 8c3de0de7ea3f..00cf0840eeae9 100644 --- a/mlir/python/mlir/dialects/transform/gpu.py +++ b/mlir/python/mlir/dialects/transform/gpu.py @@ -3,3 +3,128 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_transform_ops_gen import * +from .._gpu_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union, overload + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MapForallToBlocks(MapForallToBlocks): + """Specialization for MapForallToBlocks class.""" + + @overload + def __init__( + self, + result_type: Type, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Type, Value], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + + super().__init__( + result_type, + target, + grid_dims=grid_dims, + generate_gpu_launch=generate_gpu_launch, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MapNestedForallToThreads(MapNestedForallToThreads): + """Specialization for MapNestedForallToThreads class.""" + + @overload + def __init__( + self, + result_type: Type, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Value, Type], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + block_dims: Optional[Union[Sequence[int], Attribute]] = None, + warp_size: Optional[Union[Sequence[int], Attribute]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = result_type_or_target.type + target = result_type_or_target + super().__init__( + result_type, + target, + block_dims=block_dims, + warp_size=warp_size, + sync_after_distribute=sync_after_distribute, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py index 86f72788d86c3..6c89025f41383 100644 --- a/mlir/python/mlir/dialects/transform/loop.py +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -3,3 +3,143 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._loop_transform_ops_gen import * +from .._loop_transform_ops_gen import _Dialect + +try: + from ...ir import * + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GetParentForOp(GetParentForOp): + """Extension for GetParentForOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + num_loops: Optional[int] = None, + ip=None, + loc=None, + ): + if num_loops is None: + num_loops = 1 + super().__init__( + result_type, + _get_op_result_or_value(target), + num_loops=num_loops, + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopOutlineOp(LoopOutlineOp): + """Extension for LoopOutlineOp.""" + + def __init__( + self, + function_type: Type, + call_type: Type, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None, + ): + super().__init__( + function_type, + call_type, + _get_op_result_or_value(target), + func_name=( + func_name + if isinstance(func_name, StringAttr) + else StringAttr.get(func_name) + ), + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopPeelOp(LoopPeelOp): + """Extension for LoopPeelOp.""" + + def __init__( + self, + main_loop_type: Type, + remainder_loop_type: Type, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None, + ): + super().__init__( + main_loop_type, + remainder_loop_type, + _get_op_result_or_value(target), + fail_if_already_divisible=( + fail_if_already_divisible + if isinstance(fail_if_already_divisible, BoolAttr) + else BoolAttr.get(fail_if_already_divisible) + ), + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopPipelineOp(LoopPipelineOp): + """Extension for LoopPipelineOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None, + ): + if iteration_interval is None: + iteration_interval = 1 + if read_latency is None: + read_latency = 10 + super().__init__( + result_type, + _get_op_result_or_value(target), + iteration_interval=iteration_interval, + read_latency=read_latency, + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopUnrollOp(LoopUnrollOp): + """Extension for LoopUnrollOp.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None, + ): + super().__init__( + _get_op_result_or_value(target), + factor=factor, + ip=ip, + loc=loc, + ) diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py index 1ff04ef6a60a1..56ea61eb817f8 100644 --- a/mlir/python/mlir/dialects/transform/memref.py +++ b/mlir/python/mlir/dialects/transform/memref.py @@ -3,3 +3,118 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._memref_transform_ops_gen import * +from .._memref_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, overload, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp): + """Specialization for MemRefAllocaToGlobalOp class.""" + + @overload + def __init__( + self, + get_global_type: Type, + global_type: Type, + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... + + def __init__( + self, + get_global_type_or_alloca: Union[Operation, OpView, Type, Value], + global_type_or_none: Optional[Type] = None, + alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(get_global_type_or_alloca, Type): + get_global_type = get_global_type_or_alloca + global_type = global_type_or_none + alloca = alloca_or_none + else: + get_global_type = transform.AnyOpType.get() + global_type = transform.AnyOpType.get() + alloca = get_global_type_or_alloca + + super().__init__( + get_global_type, + global_type, + alloca, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MemRefMultiBufferOp(MemRefMultiBufferOp): + """Specialization for MemRefMultiBufferOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + factor: Union[int, IntegerAttr], + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + factor: Union[int, IntegerAttr], + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None, + factor_or_none: Optional[Union[int, IntegerAttr]] = None, + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_factor + factor = factor_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + factor = target_or_factor + + super().__init__( + transformed_type, + target, + factor, + skip_analysis=skip_analysis, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py index b1515287a3f1f..bb5fa7ffd3065 100644 --- a/mlir/python/mlir/dialects/transform/pdl.py +++ b/mlir/python/mlir/dialects/transform/pdl.py @@ -3,3 +3,53 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._transform_pdl_extension_ops_gen import * +from .._transform_pdl_extension_ops_gen import _Dialect + +try: + from ...ir import * + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class PDLMatchOp(PDLMatchOp): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + pattern_name, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class WithPDLPatternsOp(WithPDLPatternsOp): + def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None): + root = _get_op_result_or_value(target) if not isinstance(target, Type) else None + root_type = target if isinstance(target, Type) else root.type + super().__init__(root=root, loc=loc, ip=ip) + self.regions[0].blocks.append(root_type) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index cb3812301dbd4..284c93823acbd 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -3,4 +3,777 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._structured_transform_ops_gen import * +from .._structured_transform_ops_gen import _Dialect from .._structured_transform_enum_gen import * + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import List, Optional, Sequence, Tuple, Union, overload + +StaticIntLike = Union[int, IntegerAttr] +ValueLike = Union[Operation, OpView, Value] +MixedInt = Union[StaticIntLike, ValueLike] + +IntOrAttrList = Sequence[Union[IntegerAttr, int]] +OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] + +BoolOrAttrList = Sequence[Union[BoolAttr, bool]] +OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] + +MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] + +DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]] + + +def _dispatch_dynamic_index_list( + indices: Union[DynamicIndexList, ArrayAttr], +) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]: + """Dispatches a list of indices to the appropriate form. + + This is similar to the custom `DynamicIndexList` directive upstream: + provided indices may be in the form of dynamic SSA values or static values, + and they may be scalable (i.e., as a singleton list) or not. This function + dispatches each index into its respective form. It also extracts the SSA + values and static indices from various similar structures, respectively. + """ + dynamic_indices = [] + static_indices = [ShapedType.get_dynamic_size()] * len(indices) + scalable_indices = [False] * len(indices) + + # ArrayAttr: Extract index values. + if isinstance(indices, ArrayAttr): + indices = [idx for idx in indices] + + def process_nonscalable_index(i, index): + """Processes any form of non-scalable index. + + Returns False if the given index was scalable and thus remains + unprocessed; True otherwise. + """ + if isinstance(index, int): + static_indices[i] = index + elif isinstance(index, IntegerAttr): + static_indices[i] = index.value # pytype: disable=attribute-error + elif isinstance(index, (Operation, Value, OpView)): + dynamic_indices.append(index) + else: + return False + return True + + # Process each index at a time. + for i, index in enumerate(indices): + if not process_nonscalable_index(i, index): + # If it wasn't processed, it must be a scalable index, which is + # provided as a Sequence of one value, so extract and process that. + scalable_indices[i] = True + assert len(index) == 1 + ret = process_nonscalable_index(i, index[0]) + assert ret + + return dynamic_indices, static_indices, scalable_indices + + +# Dispatches `MixedValues` that all represents integers in various forms into +# the following three categories: +# - `dynamic_values`: a list of `Value`s, potentially from op results; +# - `packed_values`: a value handle, potentially from an op result, associated +# to one or more payload operations of integer type; +# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python +# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. +# The input is in the form for `packed_values`, only that result is set and the +# other two are empty. Otherwise, the input can be a mix of the other two forms, +# and for each dynamic value, a special value is added to the `static_values`. +def _dispatch_mixed_values( + values: MixedValues, +) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]: + dynamic_values = [] + packed_values = None + static_values = None + if isinstance(values, ArrayAttr): + static_values = values + elif isinstance(values, (Operation, Value, OpView)): + packed_values = values + else: + static_values = [] + for size in values or []: + if isinstance(size, int): + static_values.append(size) + else: + static_values.append(ShapedType.get_dynamic_size()) + dynamic_values.append(size) + static_values = DenseI64ArrayAttr.get(static_values) + + return (dynamic_values, packed_values, static_values) + + +def _get_value_or_attribute_value( + value_or_attr: Union[any, Attribute, ArrayAttr] +) -> any: + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): + return value_or_attr.value + if isinstance(value_or_attr, ArrayAttr): + return _get_value_list(value_or_attr) + return value_or_attr + + +def _get_value_list( + sequence_or_array_attr: Union[Sequence[any], ArrayAttr] +) -> Sequence[any]: + return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] + + +def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: + if values is None: + return None + + # Turn into a Python list of Python ints. + values = _get_value_list(values) + + # Make an ArrayAttr of IntegerAttrs out of it. + return ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] + ) + + +def _get_int_array_array_attr( + values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] +) -> ArrayAttr: + """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. + + The input has to be a collection of collection of integers, where any + Python Sequence and ArrayAttr are admissible collections and Python ints and + any IntegerAttr are admissible integers. Both levels of collections are + turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. + If the input is None, an empty ArrayAttr is returned. + """ + if values is None: + return None + + # Make sure the outer level is a list. + values = _get_value_list(values) + + # The inner level is now either invalid or a mixed sequence of ArrayAttrs and + # Sequences. Make sure the nested values are all lists. + values = [_get_value_list(nested) for nested in values] + + # Turn each nested list into an ArrayAttr. + values = [_get_int_array_attr(nested) for nested in values] + + # Turn the outer list into an ArrayAttr. + return ArrayAttr.get(values) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class BufferizeToAllocationOp(BufferizeToAllocationOp): + """Specialization for BufferizeToAllocationOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + memory_space: Optional[Union[int, str, Attribute]] = None, + memcpy_op: Optional[str] = None, + alloc_op: Optional[str] = None, + bufferize_destination_only: Optional[bool] = None, + loc=None, + ip=None, + ): + # No other types are allowed, so hard-code those here. + allocated_buffer_type = transform.AnyValueType.get() + new_ops_type = transform.AnyOpType.get() + + if isinstance(memory_space, int): + memory_space = str(memory_space) + if isinstance(memory_space, str): + memory_space = Attribute.parse(memory_space) + + super().__init__( + allocated_buffer_type, + new_ops_type, + target, + memory_space=memory_space, + memcpy_op=memcpy_op, + alloc_op=alloc_op, + bufferize_destination_only=bufferize_destination_only, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class DecomposeOp(DecomposeOp): + """Specialization for DecomposeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + transformed_type = transform.AnyOpType.get() + super().__init__(transformed_type, target, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuseIntoContainingOp(FuseIntoContainingOp): + """Specialization for FuseIntoContainingOp class.""" + + @overload + def __init__( + self, + fused_op_type: Type, + new_containing_op_type: Type, + producer_op: Union[Operation, OpView, Value], + containing_op: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + producer_op: Union[Operation, OpView, Value], + containing_op: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value], + new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value], + producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None, + containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(fused_op_type_or_producer_op, Type): + if not isinstance(new_containing_op_type_or_containing_op, Type): + raise TypeError( + "If 'fused_op_type_or_producer_op' is a type, then " + "'new_containing_op_type_or_containing_op' is expected " + "to be one as well." + ) + fused_op_type = fused_op_type_or_producer_op + new_containing_op_type = new_containing_op_type_or_containing_op + producer_op = producer_op_or_none + containing_op = containing_op_or_none + else: + fused_op_type = transform.AnyOpType.get() + new_containing_op_type = transform.AnyOpType.get() + producer_op = fused_op_type_or_producer_op + containing_op = new_containing_op_type_or_containing_op + + super().__init__( + fused_op_type, + new_containing_op_type, + producer_op, + containing_op, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GeneralizeOp(GeneralizeOp): + """Specialization for GeneralizeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + transformed_type = transform.AnyOpType.get() + super().__init__(transformed_type, target, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InterchangeOp(InterchangeOp): + """Specialization for InterchangeOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + iterator_interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + transformed_type = transform.AnyOpType.get() + super().__init__( + transformed_type, + target, + iterator_interchange=iterator_interchange, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MapCopyToThreadsOp(MapCopyToThreadsOp): + """Specialization for MapCopyToThreadsOp class.""" + + @overload + def __init__( + self, + forall_op_type: Type, + tiled_op_type: Type, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... + + def __init__( + self, + forall_op_type_or_target: Union[Operation, OpView, Type, Value], + tiled_op_type_or_none: Optional[Type] = None, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + if isinstance(forall_op_type_or_target, Type): + forall_op_type = forall_op_type_or_target + tiled_op_type = tiled_op_type_or_none + target = target_or_none + else: + forall_op_type = transform.AnyOpType.get() + tiled_op_type = transform.AnyOpType.get() + target = forall_op_type_or_target + + super().__init__( + forall_op_type, + tiled_op_type, + target, + total_num_threads=total_num_threads, + desired_bit_alignment=desired_bit_alignment, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class VectorizeOp(VectorizeOp): + """Specialization for VectorizeOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + *, + vectorize_nd_extract: Optional[bool] = None, + scalable_sizes: OptionalBoolList = None, + static_vector_sizes: OptionalIntList = None, + loc=None, + ip=None, + ): + if ( + scalable_sizes is None + and static_vector_sizes is None + and vector_sizes is None + ): + dynamic_vector_sizes = [] + elif scalable_sizes is None and static_vector_sizes is None: + ( + dynamic_vector_sizes, + static_vector_sizes, + scalable_sizes, + ) = _dispatch_dynamic_index_list(vector_sizes) + elif scalable_sizes is None or static_vector_sizes is None: + raise TypeError( + "'scalable_sizes' and 'static_vector_sizes' must either both " + "be given explicitly or both be given as part of 'vector_sizes'." + ) + else: + dynamic_vector_sizes = vector_sizes + + super().__init__( + target, + vector_sizes=dynamic_vector_sizes, + static_vector_sizes=static_vector_sizes, + scalable_sizes=scalable_sizes, + vectorize_nd_extract=vectorize_nd_extract, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MatchOp(MatchOp): + """Specialization for MatchOp class.""" + + @overload + @classmethod + def match_op_names( + cls, + target: Union[Operation, Value], + names: Union[str, Sequence[str]], + *, + loc=None, + ip=None, + ): + ... + + @overload + @classmethod + def match_op_names( + cls, + result_type: Type, + target: Union[Operation, Value], + names: Union[str, Sequence[str]], + *, + loc=None, + ip=None, + ): + ... + + @classmethod + def match_op_names( + cls, + result_type_or_target: Union[Type, Operation, Value], + target_or_names: Union[Operation, Value, Sequence[str], str], + names_or_none: Optional[Union[Sequence[str], str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_names + names = names_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + names = target_or_names + + if isinstance(names, str): + names = [names] + + return cls( + result_type, + target, + ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MultiTileSizesOp(MultiTileSizesOp): + """Specialization for MultiTileSizesOp class.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + dimension: Union[int, IntegerAttr], + target_size: Union[int, IntegerAttr], + divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, + loc=None, + ip=None, + ): + super().__init__( + result_type, + result_type, + result_type, + target, + dimension=dimension, + target_size=target_size, + divisor=divisor, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class PadOp(PadOp): + """Specialization for PadOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + padding_dimensions: OptionalIntList = None, + pad_to_multiple_of: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] + ] = None, + copy_back_op: Optional[Union[str, StringAttr]] = None, + loc=None, + ip=None, + ): + transpose_paddings = _get_int_array_array_attr(transpose_paddings) + + any_op_type = transform.AnyOpType.get() + super().__init__( + any_op_type, + any_op_type, + any_op_type, + target, + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pad_to_multiple_of=pad_to_multiple_of, + pack_paddings=pack_paddings, + transpose_paddings=transpose_paddings, + copy_back_op=copy_back_op, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ScalarizeOp(ScalarizeOp): + """Specialization for ScalarizeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + result_type = transform.AnyOpType.get() + super().__init__(result_type, target, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SplitOp(SplitOp): + """Specialization for SplitOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + dimension: Union[int, Attribute], + split_point: Union[int, Operation, Value, Attribute], + *, + loc=None, + ip=None, + ): + if isinstance(split_point, int): + static_split_point = split_point + dynamic_split_point = None + else: + static_split_point = ShapedType.get_dynamic_size() + dynamic_split_point = split_point + + super().__init__( + target.type, + target.type, + target, + dimension=dimension, + static_split_point=static_split_point, + dynamic_split_point=dynamic_split_point, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TileUsingForOp(TileUsingForOp): + """Specialization for TileUsingForOp class.""" + + @overload + def __init__( + self, + loop_types: Union[Type, List[Type]], + target: Union[Operation, Value], + *, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ( + dynamic_sizes, + static_sizes, + scalable_sizes, + ) = _dispatch_dynamic_index_list(sizes) + + num_loops = sum(v if v == 0 else 1 for v in static_sizes) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert ( + target_or_none is None + ), "Cannot construct TileUsingForOp with two targets." + else: + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + + super().__init__( + target.type, + loop_types, + target, + dynamic_sizes=dynamic_sizes, + static_sizes=static_sizes, + interchange=interchange, + scalable_sizes=scalable_sizes, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TileUsingForallOp(TileUsingForallOp): + """Specialization for TileUsingForallOp class.""" + + @overload + def __init__( + self, + loops_type: Type, + tiled_op_type: Type, + target: Union[Operation, Value, OpView], + *, + num_threads: Optional[MixedValues] = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + num_threads: Optional[MixedValues] = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + loops_type_or_target: Union[ + Type, Union[Operation, Value, OpView] # loops_type + ], # target + tiled_op_type_or_none: Optional[Type] = None, + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + num_threads: MixedValues = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + # `Type` arguments in the front are optional: add default values to front. + if isinstance(loops_type_or_target, Type): + # First overload: type arguments provided. + if not isinstance(tiled_op_type_or_none, Type): + raise TypeError( + "If 'loops_type_or_target' is a type, then " + "'tiled_op_type_or_none' is expected to be one as well." + ) + loops_type = loops_type_or_target + tiled_op_type = tiled_op_type_or_none + target = target_or_none + else: + # Last overload: type arguments missing. + loops_type = transform.AnyOpType.get() + tiled_op_type = transform.AnyOpType.get() + target = loops_type_or_target + + # Unpack mixed num_threads. + ( + dynamic_num_threads, + packed_num_threads, + num_threads_attr, + ) = _dispatch_mixed_values(num_threads) + + # Unpack mixed tile_sizes. + ( + dynamic_tile_sizes, + packed_tile_sizes, + tile_sizes_attr, + ) = _dispatch_mixed_values(tile_sizes) + + super().__init__( + loops_type, + tiled_op_type, + target=target, + tile_sizes=dynamic_tile_sizes, + packed_tile_sizes=packed_tile_sizes, + static_tile_sizes=tile_sizes_attr, + num_threads=dynamic_num_threads, + packed_num_threads=packed_num_threads, + static_num_threads=num_threads_attr, + mapping=mapping, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp): + """Specialization for VectorizeChildrenAndApplyPatternsOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + disable_multi_reduction_to_contract_patterns: bool = False, + disable_transfer_permutation_map_lowering_patterns: bool = False, + vectorize_nd_extract: bool = False, + vectorize_padding: bool = False, + loc=None, + ip=None, + ): + transformed_type = transform.AnyOpType.get() + super().__init__( + transformed_type, + target, + disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, + disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, + vectorize_nd_extract=vectorize_nd_extract, + vectorize_padding=vectorize_padding, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py index bf52255b3df71..4eb30398f0872 100644 --- a/mlir/python/mlir/dialects/transform/tensor.py +++ b/mlir/python/mlir/dialects/transform/tensor.py @@ -3,3 +3,67 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._tensor_transform_ops_gen import * +from .._tensor_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, overload, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MakeLoopIndependentOp(MakeLoopIndependentOp): + """Specialization for MakeLoopIndependentOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + num_loops: Union[int, IntegerAttr], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + num_loops: Union[int, IntegerAttr], + *, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None, + num_loops_or_none: Optional[Union[int, IntegerAttr]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_num_loops + num_loops = num_loops_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + num_loops = target_or_num_loops + + super().__init__( + transformed_type, + target, + num_loops, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 0a3b411041b2f..f6b706f9bc8ae 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -114,6 +114,7 @@ def get_unranked_memref_descriptor(nparray): d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) return d + def move_aligned_ptr_by_offset(aligned_ptr, offset): """Moves the supplied ctypes pointer ahead by `offset` elements.""" aligned_addr = ctypes.addressof(aligned_ptr.contents) @@ -122,6 +123,7 @@ def move_aligned_ptr_by_offset(aligned_ptr, offset): content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr)) return content_ptr + def unranked_memref_to_numpy(unranked_memref, np_dtype): """Converts unranked memrefs to numpy arrays.""" ctp = as_ctype(np_dtype) @@ -139,10 +141,10 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype): def ranked_memref_to_numpy(ranked_memref): """Converts ranked memrefs to numpy arrays.""" - content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset) - np_arr = np.ctypeslib.as_array( - content_ptr, shape=ranked_memref[0].shape + content_ptr = move_aligned_ptr_by_offset( + ranked_memref[0].aligned, ranked_memref[0].offset ) + np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape) strided_arr = np.lib.stride_tricks.as_strided( np_arr, np.ctypeslib.as_array(ranked_memref[0].shape), diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index f4a793aee4aa1..6d1c5eab75898 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -33,3 +33,16 @@ def testFastMathFlags(): ) # CHECK: %0 = arith.addf %cst, %cst fastmath : f32 print(r) + + +# CHECK-LABEL: TEST: testArithValueBuilder +@run +def testArithValueBuilder(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32_t = F32Type.get() + + with InsertionPoint(module.body): + a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) + # CHECK: %cst = arith.constant 4.242000e+01 : f32 + print(a) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 49f3a951426d0..c8ef84721090a 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -30,14 +30,9 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from ._ods_common import _cext as _ods_cext -from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results +from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results _ods_ir = _ods_cext.ir -try: - from . import _{0}_ops_ext as _ods_ext_module -except ImportError: - _ods_ext_module = None - import builtins from typing import Sequence as _Sequence, Union as _Union @@ -62,7 +57,6 @@ from ._{0}_ops_gen import _Dialect /// {1} is the operation name. constexpr const char *opClassTemplate = R"Py( @_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) class {0}(_ods_ir.OpView): OPERATION_NAME = "{1}" )Py"; @@ -301,17 +295,17 @@ static bool isODSReserved(StringRef str) { /// (does not change the `name` if it already is suitable) and returns the /// modified version. static std::string sanitizeName(StringRef name) { - std::string processed_str = name.str(); + std::string processedStr = name.str(); std::replace_if( - processed_str.begin(), processed_str.end(), + processedStr.begin(), processedStr.end(), [](char c) { return !llvm::isAlnum(c); }, '_'); - if (llvm::isDigit(*processed_str.begin())) - return "_" + processed_str; + if (llvm::isDigit(*processedStr.begin())) + return "_" + processedStr; - if (isPythonReserved(processed_str) || isODSReserved(processed_str)) - return processed_str + "_"; - return processed_str; + if (isPythonReserved(processedStr) || isODSReserved(processedStr)) + return processedStr + "_"; + return processedStr; } static std::string attrSizedTraitForKind(const char *kind) { @@ -853,10 +847,6 @@ populateBuilderRegions(const Operator &op, /// rebuild anew). static llvm::SmallVector emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { - // If we are asked to skip default builders, comply. - if (op.skipDefaultBuilders()) - return {}; - llvm::SmallVector builderArgs; llvm::SmallVector builderLines; llvm::SmallVector operandArgNames; @@ -989,9 +979,6 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) { static void emitValueBuilder(const Operator &op, llvm::SmallVector functionArgs, raw_ostream &os) { - // If we are asked to skip default builders, comply. - if (op.skipDefaultBuilders()) - return; // Params with (possibly) default args. auto valueBuilderParams = llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) { @@ -1010,9 +997,9 @@ static void emitValueBuilder(const Operator &op, auto lhs = *llvm::split(arg, "=").begin(); return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str(); }); - std::string name_without_dialect = + std::string nameWithoutDialect = op.getOperationName().substr(op.getOperationName().find('.') + 1); - os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect), + os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect), op.getCppClassName(), llvm::join(valueBuilderParams, ", "), llvm::join(opBuilderArgs, ", "), @@ -1051,11 +1038,8 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { if (clDialectName.empty()) llvm::PrintFatalError("dialect name not provided"); - bool isExtension = !clDialectExtensionName.empty(); - os << llvm::formatv(fileHeader, isExtension - ? clDialectExtensionName.getValue() - : clDialectName.getValue()); - if (isExtension) + os << fileHeader; + if (!clDialectExtensionName.empty()) os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue()); else os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());