
Editable mode is erroring out with flatc message #8784

@mergennachin


🐛 Describe the bug

The contents of model_lower.py are as follows (same as this):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from torch.export import Dim, export

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, 3),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, 3),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d([1,1])
        )
        self.linear = torch.nn.Linear(16, 10)

    def forward(self, x):
        y = self.seq(x)
        y = torch.flatten(y, 1)
        y = self.linear(y)
        return y

model = Model()
inputs = (torch.randn(1,1,16,16),)
dynamic_shapes = {
    "x": {
        2: Dim("h", min=16, max=1024),
        3: Dim("w", min=16, max=1024),
    }
}

exported_program = export(model, inputs, dynamic_shapes=dynamic_shapes)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner = [XnnpackPartitioner()]
).to_executorch()

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)

When I run ./python model_lower.py, it fails with:

Traceback (most recent call last):
  File "/Users/mnachin/miniconda/envs/executorch_test/bin/flatc", line 5, in <module>
    from executorch.data.bin import flatc
ModuleNotFoundError: No module named 'executorch.data'
Traceback (most recent call last):
  File "/Users/mnachin/executorch/model_lower.py", line 34, in <module>
    executorch_program = to_edge_transform_and_lower(
  File "/Users/mnachin/executorch/exir/program/_program.py", line 101, in wrapper
    return func(self, *args, **kwargs)
  File "/Users/mnachin/executorch/exir/program/_program.py", line 1107, in to_edge_transform_and_lower
    edge_manager = edge_manager.to_backend({name: curr_partitioner})
  File "/Users/mnachin/executorch/exir/program/_program.py", line 101, in wrapper
    return func(self, *args, **kwargs)
  File "/Users/mnachin/executorch/exir/program/_program.py", line 1363, in to_backend
    new_edge_programs[name] = to_backend(program, partitioner[name])
  File "/Users/mnachin/miniconda/envs/executorch_test/lib/python3.10/functools.py", line 878, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/Users/mnachin/executorch/exir/backend/backend_api.py", line 396, in _
    tagged_graph_module = _partition_and_lower(
  File "/Users/mnachin/executorch/exir/backend/backend_api.py", line 319, in _partition_and_lower
    partitioned_module = _partition_and_lower_one_graph_module(
  File "/Users/mnachin/executorch/exir/backend/backend_api.py", line 249, in _partition_and_lower_one_graph_module
    lowered_submodule = to_backend(
  File "/Users/mnachin/miniconda/envs/executorch_test/lib/python3.10/functools.py", line 878, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/Users/mnachin/executorch/exir/backend/backend_api.py", line 113, in _
    preprocess_result: PreprocessResult = cls.preprocess(
  File "/Users/mnachin/executorch/backends/xnnpack/xnnpack_preprocess.py", line 190, in preprocess
    processed_bytes=serialize_xnnpack_binary(
  File "/Users/mnachin/executorch/backends/xnnpack/serialization/xnnpack_graph_serialize.py", line 344, in serialize_xnnpack_binary
    flatbuffer_payload = convert_to_flatbuffer(xnnpack_graph)
  File "/Users/mnachin/executorch/backends/xnnpack/serialization/xnnpack_graph_serialize.py", line 325, in convert_to_flatbuffer
    _flatc_compile(d, schema_path, json_path)
  File "/Users/mnachin/executorch/exir/_serialize/_flatbuffer.py", line 213, in _flatc_compile
    _run_flatc(
  File "/Users/mnachin/executorch/exir/_serialize/_flatbuffer.py", line 199, in _run_flatc
    subprocess.run([flatc_path] + list(args), check=True)
  File "/Users/mnachin/miniconda/envs/executorch_test/lib/python3.10/subprocess.py", line 524, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['flatc', '--binary', '-o', '/var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/tmpxmwvou0t', '/var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/tmpxmwvou0t/schema.fbs', '/var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/tmpxmwvou0t/schema.json']' returned non-zero exit status 1.
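
The first traceback suggests that the flatc found on PATH is a Python wrapper script that imports executorch.data.bin, and that this module is not importable in the editable install. A minimal diagnostic sketch for checking both facts is below. It is not part of the original report: the helper names module_available and diagnose_flatc are hypothetical, and it assumes flatc is resolved through PATH, as the 'flatc' entry in the failing command indicates.

```python
import importlib.util
import shutil


def module_available(name: str) -> bool:
    """Return True if `name` can be located by the import system.

    find_spec imports parent packages, but not the target module itself.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # ModuleNotFoundError (an ImportError) is raised when a parent
        # package such as `executorch` cannot be imported at all.
        return False


def diagnose_flatc() -> dict:
    """Collect the two facts the traceback above hinges on."""
    return {
        # Path of the `flatc` that a PATH lookup would resolve (None if absent).
        "flatc_on_path": shutil.which("flatc"),
        # The wrapper in the traceback runs `from executorch.data.bin import flatc`.
        "executorch_data_importable": module_available("executorch.data"),
    }


if __name__ == "__main__":
    for key, value in diagnose_flatc().items():
        print(f"{key}: {value}")
```

If flatc_on_path points into the conda environment's bin directory while executorch_data_importable is False, that would be consistent with the ModuleNotFoundError shown above.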

Versions

Collecting environment information...
PyTorch version: 2.7.0.dev20250131
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.3.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 3.31.4
Libc version: N/A

Python version: 3.10.0 (default, Mar 3 2022, 03:54:28) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-15.3.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+0ab3499
[pip3] numpy==2.2.3
[pip3] torch==2.7.0.dev20250131
[pip3] torchao==0.10.0+git7d879462
[pip3] torchaudio==2.6.0.dev20250131
[pip3] torchsr==1.0.4
[pip3] torchvision==0.22.0.dev20250131
[conda] executorch 0.6.0a0+0ab3499 pypi_0 pypi
[conda] numpy 2.2.3 pypi_0 pypi
[conda] torch 2.7.0.dev20250131 pypi_0 pypi
[conda] torchao 0.10.0+git7d879462 pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250131 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250131 pypi_0 pypi

cc @byjlw

Labels: module: user experience, triaged

Status: Done