
torch.split fails in to_edge #11723

Open
@GregoryComer

Description

@GregoryComer
Member

🐛 Describe the bug

The torch.split operator fails during lowering with an error related to aliasing.
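To see why aliasing comes up at all, here is a small pure-PyTorch check that needs no ExecuTorch. The chunks returned by `torch.split` are views that share storage with the input, and the ATen schema for `split.Tensor` records this with the same `(a)` alias annotation that appears in the error output. This is a sketch assuming PyTorch 2.x, where `Tensor.untyped_storage()` is available.

```python
import torch

x = torch.randn(6, 10)
chunks = torch.split(x, 2)  # chunks of size 2 along dim 0 -> three (2, 10) views

# Every chunk is a view: it shares the input's underlying storage.
base_ptr = x.untyped_storage().data_ptr()
same_storage = [c.untyped_storage().data_ptr() == base_ptr for c in chunks]
print(len(chunks), [tuple(c.shape) for c in chunks], same_storage)

# Writing through a chunk therefore changes the input as well.
chunks[0].zero_()
print(bool((x[:2] == 0).all()))  # True

# The ATen schema encodes this aliasing with the "(a)" annotations.
schema = torch.ops.aten.split.Tensor._schema
print(schema)
```

Three chunks also match the `num_users=3` on the `%split` node in the error.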

Repro:

```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig, to_edge
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Split along dim 0 into chunks of size 2: a (6, 10) input yields three (2, 10) views.
        return torch.split(x, 2)


model = Model()
inputs = (torch.randn(6, 10),)
print(inputs)

# torch.split returns a tuple of tensors, so there is no single .shape to print.
eager_outputs = model(*inputs)

ep = torch.export.export(model.eval(), inputs)
print(ep)
print(f"EP: {ep.module()(*inputs)}")

# Per the report, lowering fails here with the RuntimeError shown under "Output".
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[CoreMLPartitioner()],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
).to_executorch()

print(lowered.exported_program())

et_model = _load_for_executorch_from_buffer(lowered.buffer)
et_outputs = et_model([*inputs])  # one output tensor per chunk
print(et_outputs)

# Compare chunk by chunk; subtracting one tuple from another is not defined.
for eager, et in zip(eager_outputs, et_outputs):
    print(torch.allclose(eager, et))
```

Output:

```
RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: EDGE_DO_NOT_DECOMP::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]. We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub.

While executing %split : [num_users=3] = call_function[target=torch.ops.EDGE_DO_NOT_DECOMP.split.Tensor](args = (%x, 2), kwargs = {})
```

Versions

executorch commit 67b6009 (Jun 14)

cc @JacobSzwejbka @angelayi

Activity

Added labels module: exir (Issues related to Export IR and the code under exir/) and backend tester (This bug was found by the backend test suite.) on Jun 16, 2025
Added the issue type on Jun 16, 2025
JacobSzwejbka (Contributor) commented on Jun 27, 2025

The error that complains about aliasing could then point people to this list as well.

JacobSzwejbka (Contributor) commented on Jun 27, 2025

@angelayi, is there a way we could automatically convert all of the view ops instead of maintaining this manual dictionary? Sort of like how it's trivial to convert a functional schema_kind to inplace.
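ATen already ships functional `_copy` variants of its view ops (for example `aten::split_copy.Tensor`, whose outputs are freshly allocated rather than aliased). One possible reading of the question above is a name-based lookup from a view op to its `_copy` counterpart, in place of a hand-maintained mapping. The helper below is a hypothetical sketch of that idea only, not ExecuTorch code. It assumes every op it is given has a `<name>_copy` sibling with the same overload name, and raises `AttributeError` otherwise.

```python
import torch


def to_copy_variant(op: torch._ops.OpOverload) -> torch._ops.OpOverload:
    """Hypothetical helper: find the `<name>_copy` counterpart of a view op
    by naming convention, e.g. aten.split.Tensor -> aten.split_copy.Tensor."""
    namespace, base_name = op._schema.name.split("::")  # "aten", "split"
    packet = getattr(getattr(torch.ops, namespace), f"{base_name}_copy")
    return getattr(packet, op._overloadname)  # same overload, e.g. "Tensor"


copy_op = to_copy_variant(torch.ops.aten.split.Tensor)
print(copy_op._schema)  # the outputs carry no alias annotation

x = torch.randn(6, 10)
outs = copy_op(x, 2)
base_ptr = x.untyped_storage().data_ptr()
print([o.untyped_storage().data_ptr() == base_ptr for o in outs])  # [False, False, False]
```

A naming convention is brittle for ops outside `aten` or without a `_copy` sibling, so a real implementation would still need a fallback for those cases.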


Metadata


Assignees

No one assigned

    Labels

    backend tester (This bug was found by the backend test suite.), good first issue (Good for newcomers), module: exir (Issues related to Export IR and the code under exir/)

    Type

    Projects

    Status

    To triage

    Status

    No status

    Milestone

    No milestone

    Relationships

    None yet

      Development

      No branches or pull requests

        Participants

        @GregoryComer @JacobSzwejbka


          torch.split fails in to_edge · Issue #11723 · pytorch/executorch