🐛 Describe the bug
The torch.split operator fails during lowering with an error about alias annotations.
Repro:
import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig, to_edge
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return torch.split(x, 2)
model = Model()
inputs = (
    torch.randn(6, 10),
)
print(inputs)
eager_outputs = model(*inputs)
#print(f"Eager: {eager_outputs.shape} {eager_outputs}")
ep = torch.export.export(model.eval(), inputs)
print(ep)
print(f"EP: {ep.module()(*inputs)}")
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[CoreMLPartitioner()],
    compile_config=EdgeCompileConfig(_check_ir_validity=False)
).to_executorch()
print(lowered.exported_program())
et_model = _load_for_executorch_from_buffer(lowered.buffer)
et_outputs = et_model([*inputs])[0]
print(et_outputs)
# torch.split returns a tuple, so compare against the first chunk only
print(et_outputs - eager_outputs[0])
Output:
RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: EDGE_DO_NOT_DECOMP::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]. We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub.
While executing %split : [num_users=3] = call_function[target=torch.ops.EDGE_DO_NOT_DECOMP.split.Tensor](args = (%x, 2), kwargs = {})
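To make the complaint concrete, here is a minimal sketch that inspects the schema string quoted in the error and reports which return entries carry an alias annotation. It is plain string parsing, not the check PyTorch itself performs; the helper name `aliased_returns` and the `my_ns::...` schemas are hypothetical and exist only for illustration.

```python
import re

# Matches a type followed by an alias annotation, e.g. "Tensor(a)" or "Tensor(a!)".
_ALIASED_TYPE = re.compile(r"\w+\(\w+!?(?: -> [^)]*)?\)")


def aliased_returns(schema: str) -> list[str]:
    """Return the return entries of a schema string that carry an alias annotation."""
    # The return declaration follows the ") -> " that closes the argument list.
    _, sep, returns = schema.rpartition(") -> ")
    if not sep:
        raise ValueError(f"not an operator schema: {schema!r}")
    returns = returns.strip()
    # Multiple returns are written as a parenthesized, comma-separated tuple.
    if returns.startswith("("):
        returns = returns[1:-1]
    return [r.strip() for r in returns.split(",") if _ALIASED_TYPE.search(r)]


# Schema copied from the error message above.
schema = (
    "EDGE_DO_NOT_DECOMP::split.Tensor(Tensor(a -> *) self, "
    "SymInt split_size, int dim=0) -> Tensor(a)[]"
)
print(aliased_returns(schema))
```

For the schema in the error this reports the single return `Tensor(a)[]`, which is the annotation functionalization rejects for a non-ATen operator.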
Versions
executorch commit 67b6009 (Jun 14)
JacobSzwejbka commented on Jun 27, 2025
Can you try adding split and split_copy to https://github.com/pytorch/executorch/blob/main/exir/passes/replace_broken_ops_with_function_ops_pass.py#L17 ?
JacobSzwejbka commented on Jun 27, 2025
And then the error that complains about aliasing can point people to this list as well.
JacobSzwejbka commented on Jun 27, 2025
@angelayi is there a way we could just auto-convert all of the view ops instead of maintaining this manual dictionary? Sort of like how it's trivial to convert a functional schema_kind to inplace.
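As a rough sketch of what auto-conversion could look like, the snippet below derives the `_copy` name from a view op's qualified name instead of listing each pair by hand. It assumes the convention that a view op `ns::name.overload` has a functional counterpart `ns::name_copy.overload`; that holds for pairs such as split / split_copy but is not verified here for every view op, so a real pass would still need to confirm the derived overload exists. The function `copy_variant` and the `view_ops` list are hypothetical and work on plain strings, not on the actual op objects the pass uses.

```python
def copy_variant(qualified_name: str) -> str:
    """Derive the assumed '_copy' counterpart, e.g. 'aten::split.Tensor' -> 'aten::split_copy.Tensor'."""
    namespace, sep, rest = qualified_name.partition("::")
    if not sep or not rest:
        raise ValueError(f"expected 'namespace::name[.overload]', got {qualified_name!r}")
    base, dot, overload = rest.partition(".")
    # Names that are already the functional variant are returned unchanged.
    if base.endswith("_copy"):
        return qualified_name
    return f"{namespace}::{base}_copy{dot}{overload}"


# Hypothetical list of view ops; the mapping is generated rather than hand-written.
view_ops = ["aten::split.Tensor", "aten::view", "aten::permute"]
view_to_copy = {name: copy_variant(name) for name in view_ops}
print(view_to_copy)
```

The manual dictionary could then be kept only for exceptions that do not follow the naming rule.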