Skip to content

Commit e5ed112

Browse files
Add aot example with Neutron Backend
Co-authored-by: Martin Pavella <[email protected]>
1 parent f8e7264 commit e5ed112

File tree

4 files changed

+612
-0
lines changed

4 files changed

+612
-0
lines changed

examples/nxp/README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# PyTorch Model Delegation to Neutron Backend
2+
3+
In this guide we show how to use the ExecuTorch ahead-of-time (AoT) flow to convert a PyTorch model to the ExecuTorch format and delegate the model computation to the eIQ Neutron NPU using the eIQ Neutron Backend.
4+
5+
First we will start with an example script that converts the model. This example shows the CifarNet model preparation. It is the same model which is part of the `example_cifarnet`.
6+
7+
The steps are expected to be executed from the executorch root folder.
8+
1. After building ExecuTorch you should have `libquantized_ops_aot_lib.so` and `_portable_lib.<python_version>.so` located under the `pip-out/lib.*` folder. We will need these libraries when generating the quantized CifarNet ExecuTorch model, so as a first step we will find them:
9+
```commandline
10+
$ find . -name "libquantized_ops_aot_lib.so"
11+
./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/kernels/quantized/libquantized_ops_aot_lib.so
12+
13+
$ find . -name "_portable_lib.cpython-310d-x86_64-linux-gnu.so"
14+
./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/extension/pybindings/_portable_lib.cpython-310d-x86_64-linux-gnu.so
15+
```
16+
17+
2. Now run the `aot_neutron_compile.py` example with the `cifar10` model
18+
```commandline
19+
$ python examples/nxp/aot_neutron_compile.py \
20+
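--quantize --so_library ./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/kernels/quantized/libquantized_ops_aot_lib.so \
--portable_lib ./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/extension/pybindings/_portable_lib.cpython-310d-x86_64-linux-gnu.so \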
21+
--delegate --neutron_converter_flavor SDK_25_03 -m cifar10
22+
```
23+
24+
3. This generates the `cifar10_nxp_delegate.pte` file, which can be used with the MCUXpresso SDK `cifarnet_example` project.

examples/nxp/aot_neutron_compile.py

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
# Copyright 2024-2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# Example script to compile the model for the NXP Neutron NPU
7+
8+
import argparse
9+
import io
10+
import logging
11+
import sys
12+
from collections import defaultdict
13+
from typing import Iterator
14+
15+
import torch
16+
17+
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
18+
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
19+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
20+
from executorch.examples.models import MODEL_NAME_TO_MODEL
21+
from executorch.examples.models.model_factory import EagerModelFactory
22+
23+
from executorch.exir import (
24+
EdgeCompileConfig,
25+
ExecutorchBackendConfig,
26+
to_edge_transform_and_lower,
27+
)
28+
from executorch.extension.export_util import save_pte_program
29+
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
30+
from torch.export import export
31+
32+
from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
33+
34+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
35+
logging.basicConfig(level=logging.INFO, format=FORMAT)
36+
37+
38+
def print_ops_in_edge_program(edge_program):
    """Print every operator called in `edge_program` together with its use count.

    One line is printed per operator, ordered from the most to the least used.
    """
    # Operator name -> number of call nodes that use it.
    usage = defaultdict(int)

    for graph_node in edge_program.graph.nodes:
        # Only `call*` nodes are operators; `placeholder` and `output` nodes are skipped.
        if "call" in graph_node.op:
            target = graph_node.target
            if hasattr(target, "_schema"):
                # Regular op: use the name stored in its schema.
                # noinspection PyProtectedMember
                name = target._schema.schema.name
            else:
                # Builtin function: fall back to its string form.
                name = str(target)
            usage[name] += 1

    # Most frequently used ops first; the sort is stable, so ties keep first-seen order.
    ranked = sorted(usage.items(), key=lambda item: item[1], reverse=True)

    for name, uses in ranked:
        print(f"{name: <50} {uses}x")
64+
65+
66+
def get_model_and_inputs_from_name(model_name: str):
    """Look up an example PyTorch model by name.

    Returns a `(model, example_inputs, calibration_inputs)` tuple. `calibration_inputs` is None unless the model
    comes from this file's `models` table.

    Raises RuntimeError if there is no example model corresponding to the given name.
    """
    # Models registered in this file are checked first; they also supply calibration data.
    if model_name in models:
        wrapper = models[model_name]()
        return (
            wrapper.get_eager_model(),
            wrapper.get_example_inputs(),
            wrapper.get_calibration_inputs(64),
        )

    # Anything else must come from executorch/examples/models/.
    if model_name not in MODEL_NAME_TO_MODEL.keys():
        raise RuntimeError(
            f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
        )

    logging.warning(
        "Using a model from examples/models not all of these are currently supported"
    )
    factory_args = MODEL_NAME_TO_MODEL[model_name]
    eager_model, inputs, _ = EagerModelFactory.create_model(*factory_args)
    return eager_model, inputs, None
93+
94+
95+
# Example models defined for this script, keyed by the name accepted by `--model_name`.
# `get_model_and_inputs_from_name` consults this table before `MODEL_NAME_TO_MODEL`.
models = {
    "cifar10": CifarNet,
}
98+
99+
100+
def post_training_quantize(
    model, calibration_inputs: tuple[torch.Tensor] | Iterator[tuple[torch.Tensor]]
):
    """Quantize the provided model using post-training quantization.

    The model is prepared with a `NeutronQuantizer`, run on the calibration data and then converted with
    `convert_pt2e`.

    :param model: Aten model to quantize.
    :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
                               input. Or an iterator over such tuples.
    :return: The quantized model returned by `convert_pt2e`.
    """
    # Based on executorch.examples.arm.aot_arm_compiler.quantize
    logging.info("Quantizing model")
    logging.debug(f"---> Original model: {model}")
    quantizer = NeutronQuantizer()

    m = prepare_pt2e(model, quantizer)

    # Calibration:
    logging.debug("Calibrating model")

    if isinstance(calibration_inputs, tuple):
        # A single tuple of input tensors: one calibration pass.
        m(*calibration_inputs)
    else:
        # An iterable of input tuples. Assumption that calibration_inputs is finite.
        for i, data in enumerate(calibration_inputs):
            batch_size = data[0].shape[0]
            # Report progress roughly every 1000 samples. The interval is clamped to at least 1 so that
            # batches larger than 1000 samples (or empty batches) cannot cause a division/modulo by zero.
            log_interval = max(1, 1000 // max(1, batch_size))
            if i % log_interval == 0:
                logging.debug(f"{i * batch_size} calibration inputs done")
            m(*data)

    m = convert_pt2e(m)
    logging.debug(f"---> Quantized model: {m}")
    return m
134+
135+
136+
if __name__ == "__main__": # noqa C901
137+
parser = argparse.ArgumentParser()
138+
parser.add_argument(
139+
"-m",
140+
"--model_name",
141+
required=True,
142+
help=f"Provide model name. Valid ones: {set(list(models.keys()) + list(MODEL_NAME_TO_MODEL.keys()))}",
143+
)
144+
parser.add_argument(
145+
"-d",
146+
"--delegate",
147+
action="store_true",
148+
required=False,
149+
default=False,
150+
help="Flag for producing ArmBackend delegated model",
151+
)
152+
parser.add_argument(
153+
"--target",
154+
required=False,
155+
default="imxrt700",
156+
help="Platform for running the delegated model",
157+
)
158+
parser.add_argument(
159+
"-c",
160+
"--neutron_converter_flavor",
161+
required=False,
162+
default="SDK_25_03",
163+
help="Flavor of installed neutron-converter module. Neutron-converter module named "
164+
"'neutron_converter_SDK_24_12' has flavor 'SDK_24_12'.",
165+
)
166+
parser.add_argument(
167+
"-q",
168+
"--quantize",
169+
action="store_true",
170+
required=False,
171+
default=False,
172+
help="Produce a quantized model",
173+
)
174+
parser.add_argument(
175+
"-s",
176+
"--so_library",
177+
required=False,
178+
default=None,
179+
help="Provide path to so library. E.g., cmake-out/kernels/quantized/libquantized_ops_aot_lib.so. "
180+
"To build it update the CMake arguments: -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON"
181+
" -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON and build the quantized_ops_aot_lib package",
182+
)
183+
parser.add_argument(
184+
"-p",
185+
"--portable_lib",
186+
required=True,
187+
default=None,
188+
help="Provide path to portable_lib so library. Should be in cmake-out/_portable_lib.cpython-310-x86_64-linux-gnu.so",
189+
)
190+
parser.add_argument(
191+
"--debug", action="store_true", help="Set the logging level to debug."
192+
)
193+
parser.add_argument(
194+
"-t",
195+
"--test",
196+
action="store_true",
197+
required=False,
198+
default=False,
199+
help="Test the selected model and print the accuracy between 0 and 1.",
200+
)
201+
parser.add_argument(
202+
"--operators_not_to_delegate",
203+
required=False,
204+
default=[],
205+
type=str,
206+
nargs="*",
207+
help="List of operators not to delegate. E.g., --operators_not_to_delegate aten::convolution aten::mm",
208+
)
209+
210+
args = parser.parse_args()
211+
212+
if args.debug:
213+
logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
214+
215+
# 1. pick model from one of the supported lists
216+
model, example_inputs, calibration_inputs = get_model_and_inputs_from_name(
217+
args.model_name
218+
)
219+
model = model.eval()
220+
221+
# 2. Export the model to ATEN
222+
exported_program = torch.export.export_for_training(
223+
model, example_inputs, strict=True
224+
)
225+
226+
# TODO: Add Neutron ATen Passes, once https://github.com/pytorch/executorch/pull/10579 is merged
227+
228+
module = exported_program.module()
229+
230+
# 4. Quantize if required
231+
if args.quantize:
232+
if args.quantize and (not args.so_library or not args.portable_lib):
233+
logging.error(
234+
"Quantization enabled without supplying path to libcustom_ops_aot_lib using --so_library and "
235+
"_portable_lib.cpython* using --portable_lib CLI options. \n"
236+
"This is required for running quantized models with unquantized input."
237+
)
238+
sys.exit(-1)
239+
if calibration_inputs is None:
240+
logging.warning(
241+
"No calibration inputs available, using the example inputs instead"
242+
)
243+
calibration_inputs = example_inputs
244+
module = post_training_quantize(module, calibration_inputs)
245+
246+
# For quantization we need to build the quantized_ops_aot_lib.so and _portable_lib.*.so
247+
# Use this CMake options
248+
# -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON
249+
# -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON
250+
# and build the quantized_ops_aot_lib package
251+
# Then run with --so_library <path_to_quantized_ops_aot_lib> --portable_lib <path_to_portable_lib>
252+
if args.so_library is not None:
253+
logging.debug(
254+
f"Loading libraries: {args.so_library} and {args.portable_lib}"
255+
)
256+
torch.ops.load_library(args.portable_lib)
257+
torch.ops.load_library(args.so_library)
258+
259+
if args.test:
260+
match args.model_name:
261+
case "cifar10":
262+
accuracy = test_cifarnet_model(module)
263+
264+
case _:
265+
raise NotImplementedError(
266+
f"Testing of model `{args.model_name}` is not yet supported."
267+
)
268+
269+
cyan, end_format = "\033[96m", "\033[0m"
270+
quantized_str = "quantized " if args.quantize else ""
271+
print(
272+
f"\n{cyan}Accuracy of the {quantized_str}`{args.model_name}`: {accuracy}{end_format}\n"
273+
)
274+
275+
# 5. Export to edge program
276+
partitioner_list = []
277+
if args.delegate is True:
278+
partitioner_list = [
279+
NeutronPartitioner(
280+
generate_neutron_compile_spec(
281+
args.target,
282+
args.neutron_converter_flavor,
283+
operators_not_to_delegate=args.operators_not_to_delegate,
284+
)
285+
)
286+
]
287+
288+
edge_program = to_edge_transform_and_lower(
289+
export(module, example_inputs, strict=True),
290+
partitioner=partitioner_list,
291+
compile_config=EdgeCompileConfig(
292+
_check_ir_validity=False,
293+
),
294+
)
295+
logging.debug(f"Exported graph:\n{edge_program.exported_program().graph}")
296+
297+
# 6. Export to ExecuTorch program
298+
try:
299+
exec_prog = edge_program.to_executorch(
300+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
301+
)
302+
except RuntimeError as e:
303+
if "Missing out variants" in str(e.args[0]):
304+
raise RuntimeError(
305+
e.args[0]
306+
+ ".\nThis likely due to an external so library not being loaded. Supply a path to it with the "
307+
"--portable_lib flag."
308+
).with_traceback(e.__traceback__) from None
309+
else:
310+
raise e
311+
312+
def executorch_program_to_str(ep, verbose=False):
313+
f = io.StringIO()
314+
ep.dump_executorch_program(out=f, verbose=verbose)
315+
return f.getvalue()
316+
317+
logging.debug(f"Executorch program:\n{executorch_program_to_str(exec_prog)}")
318+
319+
# 7. Serialize to *.pte
320+
model_name = f"{args.model_name}" + (
321+
"_nxp_delegate" if args.delegate is True else ""
322+
)
323+
save_pte_program(exec_prog, model_name)
Binary file not shown.

0 commit comments

Comments
 (0)