Merged
16 changes: 11 additions & 5 deletions backends/apple/coreml/partition/coreml_partitioner.py
@@ -111,10 +111,16 @@ def ops_to_not_decompose(
         do_not_decompose = []
         op_support = OperatorsSupportedForCoreMLBackend()
         for node in ep.graph.nodes:
-            if (
-                node.op == "call_function"
-                and isinstance(node.target, torch._ops.OpOverload)
-                and op_support.is_node_supported(None, node)
+            if node.op == "call_function" and isinstance(
+                node.target, torch._ops.OpOverload
             ):
-                do_not_decompose.append(node.target)
+                try:
+                    if op_support.is_node_supported(None, node):
+                        do_not_decompose.append(node.target)
+                except Exception as e:
+                    # CoreML's op_support.is_node_supported will sometimes throw
+                    # for unsupported ops, rather than returning False
+                    logger.warning(
+                        f"Encountered exception when checking node support: {e}"
+                    )
         return do_not_decompose, None
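The behavioral change here: an exception from the support checker no longer aborts ops_to_not_decompose; the offending op is simply skipped. A minimal self-contained sketch of that pattern, assuming a checker that raises instead of returning False (FlakySupportChecker and the node strings are hypothetical stand-ins, not ExecuTorch APIs):

import logging

logger = logging.getLogger(__name__)

class FlakySupportChecker:
    # Hypothetical stand-in for OperatorsSupportedForCoreMLBackend:
    # it raises for some inputs instead of returning False.
    def is_node_supported(self, submodules, node) -> bool:
        if node == "aten.exotic_op":
            raise NotImplementedError("no CoreML lowering")
        return True

checker = FlakySupportChecker()
do_not_decompose = []
for node in ["aten.add", "aten.exotic_op"]:
    try:
        if checker.is_node_supported(None, node):
            do_not_decompose.append(node)
    except Exception as e:
        # Mirrors the hunk above: a raising checker is treated as "unsupported".
        logger.warning(f"Encountered exception when checking node support: {e}")

assert do_not_decompose == ["aten.add"]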
19 changes: 18 additions & 1 deletion backends/apple/coreml/test/test_coreml_partitioner.py
@@ -82,11 +82,28 @@ def test_vit_skip_conv(self):
 
     def test_ops_to_not_decompose(self):
         class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
             def forward(self, q, k, v, mask):
-                return torch.ops.aten.scaled_dot_product_attention.default(
+                out = torch.ops.aten.scaled_dot_product_attention.default(
                     q, k, v, attn_mask=mask
                 )
 
+                # Add non-functional and alias ops. These will be removed from
+                # ExecuTorch's non-decomposition table because they cannot be
+                # functionalized.
+                out = out.transpose(1, 2)
+                out = out.view(1, -1)
+                out = out.permute(0, 1)
+                out = out.add_(1.0)
+                out = out.mul_(2.0)
+                out = out.div_(3.0)
+                out = out.sub_(4.0)
+                out = torch.ops.aten.view_copy.default(out, (-1,))
+                out = out.select(0, 0)
+                return out
+
         model = Model()
         model.eval()
 
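The ops the test appends are exactly the categories the new filter rejects: in-place ops (mutable schemas) and view/alias ops. A hedged sketch of how one can see this from an overload's pybind schema; note the PR itself checks the torchgen native schema via _pybind_schema_to_native_schema, while this sketch uses the simpler FunctionSchema.is_mutable property and per-return alias_info:

import torch

def classify(op: torch._ops.OpOverload) -> str:
    # In-place ops have mutable schemas; view ops alias their input,
    # which shows up in the return argument's alias_info.
    schema = op._schema
    if schema.is_mutable:
        return "mutable (in-place)"
    if any(r.alias_info is not None for r in schema.returns):
        return "aliasing (view)"
    return "functional"

print(classify(torch.ops.aten.add_.Tensor))        # mutable (in-place)
print(classify(torch.ops.aten.transpose.int))      # aliasing (view)
print(classify(torch.ops.aten.view_copy.default))  # functional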
49 changes: 49 additions & 0 deletions exir/program/_program.py
@@ -26,6 +26,7 @@
 from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
 from executorch.exir.error import ExportError
 from executorch.exir.graph_module import get_control_flow_submodules
+from executorch.exir.operator.convert import _pybind_schema_to_native_schema
 from executorch.exir.pass_base import PassBase
 from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import (

@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops(
     ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
         program
     )
+    ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose(
+        ops_set_to_not_decompose
+    )
 
     for op_aten in ops_set_to_not_decompose:
         _register_no_decomp_op(op_aten)
@@ -965,6 +969,47 @@ def _sanity_check_graph_for_non_decomp_ops(
         logging.warning(warning_str)
 
 
+def _remove_invalid_ops_for_not_decompose(
+    ops_to_not_decompose: List[torch._ops.OpOverload],
+) -> List[torch._ops.OpOverload]:
+    # To address https://github.com/pytorch/executorch/issues/8781
+    def keep(op):
+        schema = op._schema
+        native_schema = _pybind_schema_to_native_schema(schema)
+        if native_schema.is_mutable:
+            logging.warn(
+                f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable."
+            )
+            return False
+
+        if native_schema.aliased_return_names() != [None]:
+            logging.warn(
+                f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output."
+            )
+            return False
+
+        # Explicit blocklist of ops that don't work if asked for preservation
+        if op in [
+            # Hits infinite recursion error when op is in
+            # EDGE_DO_NOT_DECOMP namespace
+            torch.ops.aten._to_copy.default,
+            # Scalar-to-tensor type promotion does not work on ops
+            # in EDGE_DO_NOT_DECOMP namespace
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.sub.Tensor,
+            torch.ops.aten.div.Tensor,
+        ]:
+            logging.warn(
+                f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist."
+            )
+            return False
+        return True
+
+    return list(filter(keep, ops_to_not_decompose))
+
+
 def _gen_edge_manager_for_partitioners(
     partitioner: Dict[str, List[Partitioner]],
     aten_programs: Dict[str, ExportedProgram],

Review thread on the added line torch.ops.aten._to_copy.default:

Contributor: Perhaps I didn't understand the full context, but why can't a partitioner just not specify these ops for preservation? I don't think they'll be decomposed anyways?

metascroy (Contributor, Author), Feb 27, 2025: Because if a partitioner requests these for preservation, they are no longer aten ops (they become custom ops in the EDGE_DO_NOT_DECOMP namespace), and that runs into issues during export because PyTorch has more restrictions on custom ops than aten ops.
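A hedged usage sketch of the new filter, with expected outcomes inferred from its three checks (the function is private; the import path is assumed from this file's location in the repo):

import torch
from executorch.exir.program._program import _remove_invalid_ops_for_not_decompose

requested = [
    torch.ops.aten.scaled_dot_product_attention.default,  # functional: kept
    torch.ops.aten.add_.Tensor,    # mutable schema: dropped
    torch.ops.aten.transpose.int,  # aliases its input: dropped
    torch.ops.aten.mul.Tensor,     # explicit blocklist: dropped
]
kept = _remove_invalid_ops_for_not_decompose(requested)
assert kept == [torch.ops.aten.scaled_dot_product_attention.default]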
@@ -992,6 +1037,9 @@ def _gen_edge_manager_for_partitioners(
         all_ops_no_decomp = set()
         for curr_partitioner in partitioner.get(name, []):
             curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
+            curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
+                curr_ops_no_decomp
+            )
             all_ops_no_decomp |= set(curr_ops_no_decomp)
 
         table = _default_decomposition_table()
@@ -1113,6 +1161,7 @@ def to_edge_transform_and_lower(
             curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
                 program
             )
+            curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
             ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
             _sanity_check_graph_for_non_decomp_ops(
                 name,
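Taken together, the three call sites mean a partitioner whose ops_to_not_decompose over-reports invalid ops no longer breaks export; the invalid entries are filtered (with a warning) at each site. A hedged end-to-end sketch, under the assumption that CoreMLPartitioner is importable from executorch.backends.apple.coreml.partition:

import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner  # assumed path
from executorch.exir import to_edge_transform_and_lower

class Attn(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.aten.scaled_dot_product_attention.default(
            q, k, v, attn_mask=mask
        )

q = k = v = torch.randn(1, 8, 16, 64)
mask = torch.randn(16, 16)
ep = torch.export.export(Attn().eval(), (q, k, v, mask))

# Mutable, aliasing, or blocklisted ops requested by the partitioner are
# now dropped from the preservation set instead of crashing here.
edge = to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])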