|
26 | 26 | from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
|
27 | 27 | from executorch.exir.error import ExportError
|
28 | 28 | from executorch.exir.graph_module import get_control_flow_submodules
|
| 29 | +from executorch.exir.operator.convert import _pybind_schema_to_native_schema |
29 | 30 | from executorch.exir.pass_base import PassBase
|
30 | 31 | from executorch.exir.pass_manager import PassType
|
31 | 32 | from executorch.exir.passes import (
|
@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops(
|
836 | 837 | ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
|
837 | 838 | program
|
838 | 839 | )
|
| 840 | + ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose( |
| 841 | + ops_set_to_not_decompose |
| 842 | + ) |
839 | 843 |
|
840 | 844 | for op_aten in ops_set_to_not_decompose:
|
841 | 845 | _register_no_decomp_op(op_aten)
|
@@ -965,6 +969,37 @@ def _sanity_check_graph_for_non_decomp_ops(
|
965 | 969 | logging.warning(warning_str)
|
966 | 970 |
|
967 | 971 |
|
| 972 | +def _remove_invalid_ops_for_not_decompose( |
| 973 | + ops_to_not_decompose: List[torch._ops.OpOverload], |
| 974 | +) -> List[torch._ops.OpOverload]: |
| 975 | + def keep(op): |
| 976 | + schema = op._schema |
| 977 | + native_schema = _pybind_schema_to_native_schema(schema) |
| 978 | + if native_schema.is_mutable: |
| 979 | + logging.warn(f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable.") |
| 980 | + return False |
| 981 | + |
| 982 | + if native_schema.aliased_return_names() != [None]: |
| 983 | + logging.warn(f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output.") |
| 984 | + return False |
| 985 | + |
| 986 | + # Explicit block list of ops that don't work if asked for |
| 987 | + # preservation |
| 988 | + if op in [ |
| 989 | + # Hits infinte recursion error when op is in |
| 990 | + # EDGE_DO_NOT_DECOMP namespace |
| 991 | + torch.ops.aten._to_copy.default, |
| 992 | + # type scalar->tensor type promotion does not work on op |
| 993 | + # in EDGE_DO_NOT_DECOMP namespace |
| 994 | + torch.ops.aten.mul.Tensor, |
| 995 | + ]: |
| 996 | + logging.warn(f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist.") |
| 997 | + return False |
| 998 | + return True |
| 999 | + |
| 1000 | + return list(filter(keep, ops_to_not_decompose)) |
| 1001 | + |
| 1002 | + |
968 | 1003 | def _gen_edge_manager_for_partitioners(
|
969 | 1004 | partitioner: Dict[str, List[Partitioner]],
|
970 | 1005 | aten_programs: Dict[str, ExportedProgram],
|
@@ -992,6 +1027,9 @@ def _gen_edge_manager_for_partitioners(
|
992 | 1027 | all_ops_no_decomp = set()
|
993 | 1028 | for curr_partitioner in partitioner.get(name, []):
|
994 | 1029 | curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
|
| 1030 | + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( |
| 1031 | + curr_ops_no_decomp |
| 1032 | + ) |
995 | 1033 | all_ops_no_decomp |= set(curr_ops_no_decomp)
|
996 | 1034 |
|
997 | 1035 | table = _default_decomposition_table()
|
@@ -1113,6 +1151,7 @@ def to_edge_transform_and_lower(
|
1113 | 1151 | curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
|
1114 | 1152 | program
|
1115 | 1153 | )
|
| 1154 | + curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) |
1116 | 1155 | ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
|
1117 | 1156 | _sanity_check_graph_for_non_decomp_ops(
|
1118 | 1157 | name,
|
|
0 commit comments