Skip to content

Commit 397a940

Browse files
metascroy authored and facebook-github-bot committed
Fixes to_edge_transform_and_lower when unsupported ops are asked for preservation (#8776)
Summary: If a partitioner requests that to_edge_transform_and_lower keep mutable / aliasing ops (e.g., transpose, view, permute, etc.), lowering with ExecuTorch fails because those ops cannot be functionalized when wrapped in the EDGE_DO_NOT_DECOMP namespace as custom ops. This PR filters out unsupported ops that backends request for preservation. Differential Revision: D70333876. Pulled By: metascroy
1 parent 0ab3499 commit 397a940

File tree

3 files changed

+71
-6
lines changed

3 files changed

+71
-6
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,16 @@ def ops_to_not_decompose(
111111
do_not_decompose = []
112112
op_support = OperatorsSupportedForCoreMLBackend()
113113
for node in ep.graph.nodes:
114-
if (
115-
node.op == "call_function"
116-
and isinstance(node.target, torch._ops.OpOverload)
117-
and op_support.is_node_supported(None, node)
114+
if node.op == "call_function" and isinstance(
115+
node.target, torch._ops.OpOverload
118116
):
119-
do_not_decompose.append(node.target)
117+
try:
118+
if op_support.is_node_supported(None, node):
119+
do_not_decompose.append(node.target)
120+
except Exception as e:
121+
# CoreML's op_support.is_node_supported will sometimes throw
122+
# for unsupported ops, rather than returning False
123+
logger.warning(
124+
f"Encountered exception when checking node support: {e}"
125+
)
120126
return do_not_decompose, None

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,28 @@ def test_vit_skip_conv(self):
8282

8383
def test_ops_to_not_decompose(self):
8484
class Model(torch.nn.Module):
85+
def __init__(self) -> None:
86+
super().__init__()
87+
88+
buffer = torch.ones(1)
89+
self.register_buffer("buffer", buffer)
90+
8591
def forward(self, q, k, v, mask):
86-
return torch.ops.aten.scaled_dot_product_attention.default(
92+
out = torch.ops.aten.scaled_dot_product_attention.default(
8793
q, k, v, attn_mask=mask
8894
)
8995

96+
# Add non-functional and alias ops
97+
# These will be removed by ExecuTorch in non-decomposition
98+
# table because they cannot be functionalized
99+
out = out.transpose(1, 2)
100+
out = out.view(1, -1)
101+
out = out.permute(0, 1)
102+
out = out.add_(self.buffer)
103+
out = torch.ops.aten.view_copy.default(out, (-1,))
104+
out = out.select(0, 0)
105+
return out
106+
90107
model = Model()
91108
model.eval()
92109

@@ -107,6 +124,9 @@ def forward(self, q, k, v, mask):
107124
edge_program_manager = executorch.exir.to_edge_transform_and_lower(
108125
ep, partitioner=[coreml_partitioner]
109126
)
127+
print(
128+
format_delegated_graph(edge_program_manager.exported_program().graph_module)
129+
)
110130
self.assertTrue(
111131
"executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
112132
in format_delegated_graph(

exir/program/_program.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
2727
from executorch.exir.error import ExportError
2828
from executorch.exir.graph_module import get_control_flow_submodules
29+
from executorch.exir.operator.convert import _pybind_schema_to_native_schema
2930
from executorch.exir.pass_base import PassBase
3031
from executorch.exir.pass_manager import PassType
3132
from executorch.exir.passes import (
@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops(
836837
ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
837838
program
838839
)
840+
ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose(
841+
ops_set_to_not_decompose
842+
)
839843

840844
for op_aten in ops_set_to_not_decompose:
841845
_register_no_decomp_op(op_aten)
@@ -965,6 +969,37 @@ def _sanity_check_graph_for_non_decomp_ops(
965969
logging.warning(warning_str)
966970

967971

972+
def _remove_invalid_ops_for_not_decompose(
973+
ops_to_not_decompose: List[torch._ops.OpOverload],
974+
) -> List[torch._ops.OpOverload]:
975+
def keep(op):
976+
schema = op._schema
977+
native_schema = _pybind_schema_to_native_schema(schema)
978+
if native_schema.is_mutable:
979+
logging.warn(f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable.")
980+
return False
981+
982+
if native_schema.aliased_return_names() != [None]:
983+
logging.warn(f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output.")
984+
return False
985+
986+
# Explicit block list of ops that don't work if asked for
987+
# preservation
988+
if op in [
989+
# Hits infinite recursion error when op is in
990+
# EDGE_DO_NOT_DECOMP namespace
991+
torch.ops.aten._to_copy.default,
992+
# type scalar->tensor type promotion does not work on op
993+
# in EDGE_DO_NOT_DECOMP namespace
994+
torch.ops.aten.mul.Tensor,
995+
]:
996+
logging.warn(f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist.")
997+
return False
998+
return True
999+
1000+
return list(filter(keep, ops_to_not_decompose))
1001+
1002+
9681003
def _gen_edge_manager_for_partitioners(
9691004
partitioner: Dict[str, List[Partitioner]],
9701005
aten_programs: Dict[str, ExportedProgram],
@@ -992,6 +1027,9 @@ def _gen_edge_manager_for_partitioners(
9921027
all_ops_no_decomp = set()
9931028
for curr_partitioner in partitioner.get(name, []):
9941029
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1030+
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1031+
curr_ops_no_decomp
1032+
)
9951033
all_ops_no_decomp |= set(curr_ops_no_decomp)
9961034

9971035
table = _default_decomposition_table()
@@ -1113,6 +1151,7 @@ def to_edge_transform_and_lower(
11131151
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
11141152
program
11151153
)
1154+
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
11161155
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
11171156
_sanity_check_graph_for_non_decomp_ops(
11181157
name,

0 commit comments

Comments (0)