Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c7d4b49

Browse files
robert-kalmarPop-korn
andcommittedJun 10, 2025·
Add aot example with Neutron Backend
Co-authored-by: Martin Pavella <[email protected]>
1 parent 77f16dc commit c7d4b49

File tree

5 files changed

+576
-0
lines changed

5 files changed

+576
-0
lines changed
 

‎examples/nxp/README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# PyTorch Model Delegation to Neutron Backend
2+
3+
In this guide we show how to use the ExecuTorch ahead-of-time (AoT) flow to convert a PyTorch model to the ExecuTorch format and delegate the model computation to the eIQ Neutron NPU using the eIQ Neutron Backend.
4+
5+
We start with an example script that converts the model. This example shows the preparation of the CifarNet model. It is the same model that is part of the `example_cifarnet`.
6+
7+
The steps are expected to be executed from the executorch root folder.
8+
1. Run the setup.sh script to install the neutron-converter:
9+
```commandline
10+
$ examples/nxp/setup.sh
11+
```
12+
13+
2. Now run the `aot_neutron_compile.py` example with the `cifar10` model
14+
```commandline
15+
$ python -m examples.nxp.aot_neutron_compile --quantize \
16+
--delegate --neutron_converter_flavor SDK_25_03 -m cifar10
17+
```
18+
19+
3. This generates the `cifar10_nxp_delegate.pte` file, which can be used with the MCUXpresso SDK `cifarnet_example` project.

‎examples/nxp/aot_neutron_compile.py

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
# Copyright 2024-2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# Example script to compile the model for the NXP Neutron NPU
7+
8+
import argparse
9+
import io
10+
import logging
11+
from collections import defaultdict
12+
from typing import Iterator
13+
14+
import executorch.extension.pybindings.portable_lib
15+
import executorch.kernels.quantized # noqa F401
16+
17+
import torch
18+
19+
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
20+
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
21+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
22+
from executorch.examples.models import MODEL_NAME_TO_MODEL
23+
from executorch.examples.models.model_factory import EagerModelFactory
24+
25+
from executorch.exir import (
26+
EdgeCompileConfig,
27+
ExecutorchBackendConfig,
28+
to_edge_transform_and_lower,
29+
)
30+
from executorch.extension.export_util import save_pte_program
31+
32+
from torch.export import export
33+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
34+
35+
from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
36+
37+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
38+
logging.basicConfig(level=logging.INFO, format=FORMAT)
39+
40+
41+
def print_ops_in_edge_program(edge_program):
    """Find all ops used in the `edge_program` and print them out along with their occurrence counts.

    Only nodes whose `op` contains "call" are counted. The ops are printed most used first; ops with the
    same count are printed in the order in which they were first encountered.

    :param edge_program: Program exposing `graph.nodes`, where every node has an `op` string and a `target`.
    """
    # Imported locally so the function does not rely on a module-level import of `Counter`.
    from collections import Counter

    ops_and_counts = Counter()  # Mapping ops to the number of times they are used.
    for node in edge_program.graph.nodes:
        if "call" not in node.op:
            continue  # `placeholder` or `output`. (not an operator)

        if hasattr(node.target, "_schema"):
            # Regular op.
            # noinspection PyProtectedMember
            op = node.target._schema.schema.name
        else:
            # Builtin function.
            op = str(node.target)

        ops_and_counts[op] += 1

    # `most_common()` sorts by descending count with a stable sort, so ties keep their first-seen order.
    for op, count in ops_and_counts.most_common():
        print(f"{op: <50} {count}x")
67+
68+
69+
def get_model_and_inputs_from_name(model_name: str):
    """Look up an example PyTorch model by its name.

    Models registered in the `models` mapping of this file take precedence over the ones listed in
    `MODEL_NAME_TO_MODEL` (executorch/examples/models/).

    :param model_name: Name of the example model.
    :return: Tuple `(model, example_inputs, calibration_inputs)`. The calibration inputs are None for models
             taken from executorch/examples/models/.
    :raises RuntimeError: If there is no example model corresponding to the given name.
    """
    # Case 1: Model is defined in this file
    if model_name in models:
        wrapper = models[model_name]()
        return (
            wrapper.get_eager_model(),
            wrapper.get_example_inputs(),
            wrapper.get_calibration_inputs(64),
        )

    # Case 2: Model is defined in executorch/examples/models/
    if model_name in MODEL_NAME_TO_MODEL:
        logging.warning(
            "Using a model from examples/models not all of these are currently supported"
        )
        eager_model, inputs, _ = EagerModelFactory.create_model(
            *MODEL_NAME_TO_MODEL[model_name]
        )
        return eager_model, inputs, None

    raise RuntimeError(
        f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
    )
96+
97+
98+
# Example models defined next to this script, keyed by the name accepted by `--model_name`.
models = dict(cifar10=CifarNet)
101+
102+
103+
def post_training_quantize(
    model, calibration_inputs: tuple[torch.Tensor] | Iterator[tuple[torch.Tensor]]
):
    """Quantize the provided model.

    The model is prepared with the `NeutronQuantizer`, run on the calibration inputs and then converted.

    :param model: Aten model to quantize.
    :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
                                input. Or an iterator over such tuples.
    :return: The quantized model returned by `convert_pt2e`.
    """
    # Based on executorch.examples.arm.aot_arm_compiler.quantize
    logging.info("Quantizing model")
    logging.debug(f"---> Original model: {model}")
    quantizer = NeutronQuantizer()

    m = prepare_pt2e(model, quantizer)
    # Calibration:
    logging.debug("Calibrating model")

    if isinstance(calibration_inputs, tuple):
        # A single set of inputs.
        m(*calibration_inputs)
    else:
        # Assumption that calibration_inputs is finite.
        for i, data in enumerate(calibration_inputs):
            batch_size = data[0].shape[0]
            # Report the progress roughly every 1000 inputs. The interval is clamped to at least 1, because
            # `1000 // batch_size` is 0 for batches larger than 1000 and `i % 0` raises ZeroDivisionError.
            report_interval = max(1, 1000 // max(1, batch_size))
            if i % report_interval == 0:
                logging.debug(f"{i * batch_size} calibration inputs done")
            m(*data)

    m = convert_pt2e(m)
    logging.debug(f"---> Quantized model: {m}")
    return m
136+
137+
138+
if __name__ == "__main__": # noqa C901
139+
parser = argparse.ArgumentParser()
140+
parser.add_argument(
141+
"-m",
142+
"--model_name",
143+
required=True,
144+
help=f"Provide model name. Valid ones: {set(models.keys())}",
145+
)
146+
parser.add_argument(
147+
"-d",
148+
"--delegate",
149+
action="store_true",
150+
required=False,
151+
default=False,
152+
help="Flag for producing eIQ NeutronBackend delegated model",
153+
)
154+
parser.add_argument(
155+
"--target",
156+
required=False,
157+
default="imxrt700",
158+
help="Platform for running the delegated model",
159+
)
160+
parser.add_argument(
161+
"-c",
162+
"--neutron_converter_flavor",
163+
required=False,
164+
default="SDK_25_03",
165+
help="Flavor of installed neutron-converter module. Neutron-converter module named "
166+
"'neutron_converter_SDK_24_12' has flavor 'SDK_24_12'.",
167+
)
168+
parser.add_argument(
169+
"-q",
170+
"--quantize",
171+
action="store_true",
172+
required=False,
173+
default=False,
174+
help="Produce a quantized model",
175+
)
176+
parser.add_argument(
177+
"-s",
178+
"--so_library",
179+
required=False,
180+
default=None,
181+
help="Path to custome kernel library",
182+
)
183+
parser.add_argument(
184+
"--debug", action="store_true", help="Set the logging level to debug."
185+
)
186+
parser.add_argument(
187+
"-t",
188+
"--test",
189+
action="store_true",
190+
required=False,
191+
default=False,
192+
help="Test the selected model and print the accuracy between 0 and 1.",
193+
)
194+
parser.add_argument(
195+
"--operators_not_to_delegate",
196+
required=False,
197+
default=[],
198+
type=str,
199+
nargs="*",
200+
help="List of operators not to delegate. E.g., --operators_not_to_delegate aten::convolution aten::mm",
201+
)
202+
203+
args = parser.parse_args()
204+
205+
if args.debug:
206+
logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
207+
208+
# 1. pick model from one of the supported lists
209+
model, example_inputs, calibration_inputs = get_model_and_inputs_from_name(
210+
args.model_name
211+
)
212+
model = model.eval()
213+
214+
# 2. Export the model to ATEN
215+
exported_program = torch.export.export_for_training(
216+
model, example_inputs, strict=True
217+
)
218+
219+
module = exported_program.module()
220+
221+
# 4. Quantize if required
222+
if args.quantize:
223+
if calibration_inputs is None:
224+
logging.warning(
225+
"No calibration inputs available, using the example inputs instead"
226+
)
227+
calibration_inputs = example_inputs
228+
module = post_training_quantize(module, calibration_inputs)
229+
230+
if args.so_library is not None:
231+
logging.debug(f"Loading libraries: {args.so_library} and {args.portable_lib}")
232+
torch.ops.load_library(args.so_library)
233+
234+
if args.test:
235+
match args.model_name:
236+
case "cifar10":
237+
accuracy = test_cifarnet_model(module)
238+
239+
case _:
240+
raise NotImplementedError(
241+
f"Testing of model `{args.model_name}` is not yet supported."
242+
)
243+
244+
quantized_str = "quantized " if args.quantize else ""
245+
print(f"\nAccuracy of the {quantized_str}`{args.model_name}`: {accuracy}\n")
246+
247+
# 5. Export to edge program
248+
partitioner_list = []
249+
if args.delegate is True:
250+
partitioner_list = [
251+
NeutronPartitioner(
252+
generate_neutron_compile_spec(
253+
args.target,
254+
args.neutron_converter_flavor,
255+
operators_not_to_delegate=args.operators_not_to_delegate,
256+
)
257+
)
258+
]
259+
260+
edge_program = to_edge_transform_and_lower(
261+
export(module, example_inputs, strict=True),
262+
partitioner=partitioner_list,
263+
compile_config=EdgeCompileConfig(
264+
_check_ir_validity=False,
265+
),
266+
)
267+
logging.debug(f"Exported graph:\n{edge_program.exported_program().graph}")
268+
269+
# 6. Export to ExecuTorch program
270+
try:
271+
exec_prog = edge_program.to_executorch(
272+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
273+
)
274+
except RuntimeError as e:
275+
if "Missing out variants" in str(e.args[0]):
276+
raise RuntimeError(
277+
e.args[0]
278+
+ ".\nThis likely due to an external so library not being loaded. Supply a path to it with the "
279+
"--portable_lib flag."
280+
).with_traceback(e.__traceback__) from None
281+
else:
282+
raise e
283+
284+
def executorch_program_to_str(ep, verbose=False):
285+
f = io.StringIO()
286+
ep.dump_executorch_program(out=f, verbose=verbose)
287+
return f.getvalue()
288+
289+
logging.debug(f"Executorch program:\n{executorch_program_to_str(exec_prog)}")
290+
291+
# 7. Serialize to *.pte
292+
model_name = f"{args.model_name}" + (
293+
"_nxp_delegate" if args.delegate is True else ""
294+
)
295+
save_pte_program(exec_prog, model_name)
Binary file not shown.
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
# Copyright 2024 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import argparse
7+
import itertools
8+
import logging
9+
import os.path
10+
from typing import Iterator, Tuple
11+
12+
import torch
13+
import torch.nn as nn
14+
import torch.nn.functional as F
15+
import torch.optim as optim
16+
import torchvision
17+
18+
from executorch import exir
19+
from executorch.examples.models import model_base
20+
from torchvision import transforms
21+
22+
logger = logging.getLogger(__name__)
23+
logger.setLevel(logging.INFO)
24+
25+
26+
class CifarNet(model_base.EagerModelBase):
    """`EagerModelBase` wrapper providing the CifarNet model with its example and calibration inputs."""

    def __init__(self, batch_size: int = 1, pth_file: str | None = None):
        """Store the configuration; the model itself is created by `get_eager_model`.

        :param batch_size: Batch size passed to `get_model`.
        :param pth_file: Weights file passed to `get_model`. When not given, `cifar_net.pth` located next to this
                         file is used.
        """
        self.batch_size = batch_size
        if pth_file:
            self.pth_file = pth_file
        else:
            self.pth_file = os.path.join(os.path.dirname(__file__), "cifar_net.pth")

    def get_eager_model(self) -> torch.nn.Module:
        """Return the model created by `get_model` for the stored batch size and weights file."""
        return get_model(self.batch_size, state_dict_file=self.pth_file)

    def get_example_inputs(self) -> Tuple[torch.Tensor]:
        """Return the first item of the test dataset with a leading dimension of size 1 added."""
        # The dataset returns the data and the class. We need just the data.
        image, _ = get_test_loader().dataset[0]
        return (image.unsqueeze(0),)

    def get_calibration_inputs(
        self, batch_size: int = 1
    ) -> Iterator[Tuple[torch.Tensor]]:
        """Return an iterator over 1-tuples holding the data of each batch of the test loader."""
        loader = get_test_loader(batch_size)
        # Every batch is a (data, classes) pair; the classes are dropped.
        return ((images,) for images, _ in loader)
53+
54+
55+
class CifarNetModel(nn.Module):
    """CifarNet: three padded 5x5 convolutions with max pooling, followed by a linear layer and a softmax."""

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=8, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=1, stride=2)
        self.fc = nn.Linear(in_features=1024, out_features=10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on `x` and return the softmax output of the final linear layer."""
        # Neutron Backend does not yet have passes for automated padding if the number of channels does not
        # fit the Neutron constraints (#channels == #MAC units). So the model is defined explicitly tailored for
        # Neutron-C-64. The trailing `0, 5` pads the channel dimension in addition to the spatial padding.
        x = self.pool1(self.conv1(F.pad(x, (2, 2, 2, 2, 0, 5))))
        x = self.pool1(self.conv2(F.pad(x, (2, 2, 2, 2))))
        x = self.pool2(self.conv3(F.pad(x, (2, 2, 2, 2))))

        # The output of the previous MaxPool has shape [batch, 64, 4, 4] ([batch, 4, 4, 64] in Neutron IR). When
        # running inference of the `FullyConnected`, Neutron IR will automatically collapse the channels and spatial
        # dimensions and work with a tensor of shape [batch, 1024].
        # PyTorch will combine the C and H with `batch`, and leave the last dimension (W). This will result in a
        # tensor of shape [batch * 256, 4]. This cannot be multiplied with the weight matrix of shape [1024, 10].
        # Hence the explicit reshape.
        flat = torch.reshape(x, (-1, 1024))

        return self.softmax(self.fc(flat))
94+
95+
96+
def get_train_loader(batch_size: int = 1):
    """Get loader for training data.

    :param batch_size: Batch size of the returned loader.
    :return: Shuffling `DataLoader` over the CIFAR10 training set stored in `./data`.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    preprocessing = transforms.Compose([transforms.ToTensor(), normalize])

    return torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            root="./data", train=True, download=True, transform=preprocessing
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
110+
111+
112+
def get_test_loader(batch_size: int = 1):
    """Get loader for testing data.

    :param batch_size: Batch size of the returned loader.
    :return: Non-shuffling `DataLoader` over the CIFAR10 test set stored in `./data`.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    preprocessing = transforms.Compose([transforms.ToTensor(), normalize])

    return torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            root="./data", train=False, download=True, transform=preprocessing
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )
126+
127+
128+
def get_model(
    batch_size: int = 1,
    state_dict_file: str | None = None,
    train: bool = False,
    num_epochs: int = 1,
) -> nn.Module:
    """Create the CifarNet model, optionally loading weights and/or training it.

    :param batch_size: Batch size to use during training.
    :param state_dict_file: `.pth` file. If provided and the file exists, weights will be loaded from it. Also after
                             training, the weights will be stored in the file.
    :param train: Boolean indicating whether to train the model.
    :param num_epochs: Number of epochs to use during training.
    :return: The loaded/trained CifarNet model.
    """
    cifar_net = CifarNetModel()
    has_weights_file = state_dict_file is not None

    if has_weights_file and os.path.isfile(state_dict_file):
        # Start from the pre-trained weights.
        logger.info(f"Using pre-trained weights from `{state_dict_file}`.")
        cifar_net.load_state_dict(torch.load(state_dict_file, weights_only=True))

    if not train:
        return cifar_net

    # Train the model.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(cifar_net.parameters(), lr=0.0005, momentum=0.6)
    train_loader = get_train_loader(batch_size)

    for epoch in range(num_epochs):
        running_loss = 0.0
        # Every item of the loader is a pair of [inputs, labels]; steps are counted from 1.
        for step, (inputs, labels) in enumerate(train_loader, start=1):
            # Zero the parameter gradients, then forward + backward + optimize.
            optimizer.zero_grad()
            loss = criterion(cifar_net(inputs), labels)
            loss.backward()
            optimizer.step()

            # Print the statistics every 2000 mini-batches.
            running_loss += loss.item()
            if step % 2000 == 0:
                print(f"[{epoch + 1}, {step:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

    logger.info("Finished training.")
    if has_weights_file:
        logger.info(f"Saving the trained weights in `{state_dict_file}`.")
        torch.save(cifar_net.state_dict(), state_dict_file)

    return cifar_net
183+
184+
185+
def get_cifarnet_calibration_data(num_images: int = 100) -> tuple[torch.Tensor]:
    """Return a tuple containing 1 tensor (for the 1 model input). The tensor has the shape
    [`num_images`, 3, 32, 32].

    :param num_images: Number of images taken from the training loader.
    """
    # The train loader shuffles the images. With batch size 1 every item holds a single image.
    batches = itertools.islice(get_train_loader(1), num_images)
    stacked = torch.vstack([image for image, _ in batches])
    return (stacked,)
193+
194+
195+
def test_cifarnet_model(cifar_net: nn.Module, batch_size: int = 1) -> float:
    """Test the CifarNet model on the CIFAR10 testing dataset and return the accuracy.

    This function may at some point in the future be integrated into the `CifarNet` class.

    :param cifar_net: The model to test with the CIFAR10 testing dataset.
    :param batch_size: Batch size of the test loader.
    :return: The accuracy of the model (between 0 and 1).
    """
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for images, labels in get_test_loader(batch_size):
            scores = cifar_net(images)
            # The predicted class is the index of the largest output of each sample.
            predicted = torch.max(scores.data, 1)[1]
            num_seen += labels.size(0)
            num_correct += torch.sum(predicted == labels).item()

    return num_correct / num_seen
214+
215+
216+
if __name__ == "__main__":
    # This script reports everything (including the accuracy) via `logger.info`. Without a configured handler,
    # only WARNING and above are emitted by the logging module's last-resort handler, so the output would be
    # lost. `basicConfig` does nothing if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pte-file",
        required=False,
        help="Name of a `.pte` file to save the trained model in.",
    )
    parser.add_argument(
        "--pth-file",
        required=False,
        type=str,
        help="Name of a `.pth` file to save the trained weights in. If it already exists, the model "
        "will be initialized with these weights.",
    )
    parser.add_argument(
        "--train", required=False, action="store_true", help="Train the model."
    )
    parser.add_argument(
        "--test", required=False, action="store_true", help="Test the trained model."
    )
    parser.add_argument("-b", "--batch-size", required=False, type=int, default=1)
    parser.add_argument("-e", "--num-epochs", required=False, type=int, default=1)
    args = parser.parse_args()

    # Create the model; depending on the arguments load its weights and/or train it.
    cifar_net = get_model(
        state_dict_file=args.pth_file,
        train=args.train,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
    )

    if args.test:
        logger.info("Running tests.")
        accuracy = test_cifarnet_model(cifar_net, args.batch_size)
        logger.info(f"Accuracy of the network on the 10000 test images: {accuracy}")

    if args.pte_file is not None:
        # Export the model using random inputs of the expected shape and store it as a `.pte` file.
        tracing_inputs = (torch.rand(args.batch_size, 3, 32, 32),)
        aten_dialect_program = torch.export.export(cifar_net, tracing_inputs)
        edge_dialect_program: exir.EdgeProgramManager = exir.to_edge(
            aten_dialect_program
        )
        executorch_program = edge_dialect_program.to_executorch()

        with open(args.pte_file, "wb") as file:
            logger.info(f"Saving the trained model as `{args.pte_file}`.")
            file.write(executorch_program.buffer)

‎examples/nxp/setup.sh

100644100755
File mode changed.

0 commit comments

Comments
 (0)
Please sign in to comment.