diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index bdab21af3da..1c4a13ece25 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -800,3 +800,6 @@ jobs:
       # Run pytest
       PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh
+
+      # Run the AoT example:
+      PYTHON_EXECUTABLE=python bash examples/nxp/run_aot_example.sh
diff --git a/examples/nxp/README.md b/examples/nxp/README.md
new file mode 100644
index 00000000000..66ca0785b4c
--- /dev/null
+++ b/examples/nxp/README.md
@@ -0,0 +1,20 @@
+# PyTorch Model Delegation to Neutron Backend
+
+In this guide we show how to use the ExecuTorch AoT flow to convert a PyTorch model to the ExecuTorch format and delegate the model computation to the eIQ Neutron NPU using the eIQ Neutron Backend.
+
+We start with an example script that converts the model. The example shows the preparation of the CifarNet model, the same model that is part of the `example_cifarnet` project.
+
+The steps are expected to be executed from the executorch root folder.
+1. Run the `setup.sh` script to install the neutron-converter:
+```commandline
+$ examples/nxp/setup.sh
+```
+
+2. Now run the `aot_neutron_compile.py` example with the `cifar10` model:
+```commandline
+$ python -m examples.nxp.aot_neutron_compile --quantize \
+    --delegate --neutron_converter_flavor SDK_25_03 -m cifar10
+```
+
+3. This generates a `cifar10_nxp_delegate.pte` file, which can be used with the MCUXpresso SDK `cifarnet_example` project, presented [here](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/topics/example_applications.html#how-to-build-and-run-executorch-cifarnet-example).
+To get the MCUXpresso SDK, follow this [guide](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/topics/getting_mcuxpresso.html) and use MCUXpresso SDK v25.03.00.
\ No newline at end of file
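For orientation, the flow the README describes condenses to a few calls. Below is a sketch mirroring the `aot_neutron_compile.py` script added in this patch; the `CifarNet` import path and the single-batch calibration are simplifications for brevity, not part of the patch itself.

```python
# Condensed sketch of the AoT flow; mirrors aot_neutron_compile.py below.
import torch
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.extension.export_util import save_pte_program
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Import path assumed from the repo layout of this patch.
from executorch.examples.nxp.experimental.cifar_net.cifar_net import CifarNet

cifar = CifarNet()
model = cifar.get_eager_model().eval()
example_inputs = cifar.get_example_inputs()

# Export to ATen, quantize with the Neutron PT2E quantizer, then calibrate.
module = torch.export.export_for_training(model, example_inputs, strict=True).module()
module = prepare_pt2e(module, NeutronQuantizer())
module(*example_inputs)  # calibration (a single batch here, for brevity)
module = convert_pt2e(module)

# Lower the quantized graph to the Neutron backend and serialize it.
edge = to_edge_transform_and_lower(
    torch.export.export(module, example_inputs, strict=True),
    partitioner=[
        NeutronPartitioner(generate_neutron_compile_spec("imxrt700", "SDK_25_03"))
    ],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
save_pte_program(edge.to_executorch(), "cifar10_nxp_delegate")
```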
diff --git a/examples/nxp/aot_neutron_compile.py b/examples/nxp/aot_neutron_compile.py
new file mode 100644
index 00000000000..d8e4d324de2
--- /dev/null
+++ b/examples/nxp/aot_neutron_compile.py
@@ -0,0 +1,295 @@
+# Copyright 2024-2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script to compile the model for the NXP Neutron NPU
+
+import argparse
+import io
+import logging
+from collections import defaultdict
+from typing import Iterator
+
+import executorch.extension.pybindings.portable_lib  # noqa: F401
+import executorch.kernels.quantized  # noqa: F401
+
+import torch
+
+from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
+from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.examples.models import MODEL_NAME_TO_MODEL
+from executorch.examples.models.model_factory import EagerModelFactory
+
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge_transform_and_lower,
+)
+from executorch.extension.export_util import save_pte_program
+
+from torch.export import export
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+def print_ops_in_edge_program(edge_program):
+    """Find all ops used in the `edge_program` and print them out along with their occurrence counts."""
+
+    ops_and_counts = defaultdict(
+        lambda: 0
+    )  # Mapping ops to the number of times they are used.
+    for node in edge_program.graph.nodes:
+        if "call" not in node.op:
+            continue  # `placeholder` or `output` (not an operator).
+
+        if hasattr(node.target, "_schema"):
+            # Regular op.
+            # noinspection PyProtectedMember
+            op = node.target._schema.schema.name
+        else:
+            # Builtin function.
+            op = str(node.target)
+
+        ops_and_counts[op] += 1
+
+    # Sort the ops based on how many times they are used in the model.
+    ops_and_counts = sorted(ops_and_counts.items(), key=lambda x: x[1], reverse=True)
+
+    # Print the ops and their use counts.
+    for op, count in ops_and_counts:
+        print(f"{op: <50} {count}x")
+
+
+def get_model_and_inputs_from_name(model_name: str):
+    """Given the name of an example PyTorch model, return the model, its example inputs,
+    and its calibration inputs (which can be None).
+
+    Raises RuntimeError if there is no example model corresponding to the given name.
+    """
+
+    calibration_inputs = None
+    # Case 1: Model is defined in this file.
+    if model_name in models.keys():
+        m = models[model_name]()
+        model = m.get_eager_model()
+        example_inputs = m.get_example_inputs()
+        calibration_inputs = m.get_calibration_inputs(64)
+    # Case 2: Model is defined in executorch/examples/models/.
+    elif model_name in MODEL_NAME_TO_MODEL.keys():
+        logging.warning(
+            "Using a model from examples/models; not all of these are currently supported."
+        )
+        model, example_inputs, _ = EagerModelFactory.create_model(
+            *MODEL_NAME_TO_MODEL[model_name]
+        )
+    else:
+        raise RuntimeError(
+            f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
+        )
+
+    return model, example_inputs, calibration_inputs
+
+
+models = {
+    "cifar10": CifarNet,
+}
+
+
+def post_training_quantize(
+    model, calibration_inputs: tuple[torch.Tensor] | Iterator[tuple[torch.Tensor]]
+):
+    """Quantize the provided model.
+
+    :param model: ATen model to quantize.
+    :param calibration_inputs: Either a tuple of calibration input tensors, where each element
+        corresponds to a model input, or an iterator over such tuples.
+    """
+    # Based on executorch.examples.arm.aot_arm_compiler.quantize
+    logging.info("Quantizing model")
+    logging.debug(f"---> Original model: {model}")
+    quantizer = NeutronQuantizer()
+
+    m = prepare_pt2e(model, quantizer)
+    # Calibration:
+    logging.debug("Calibrating model")
+
+    def _get_batch_size(data):
+        return data[0].shape[0]
+
+    if not isinstance(
+        calibration_inputs, tuple
+    ):  # Assumption that calibration_inputs is finite.
+        for i, data in enumerate(calibration_inputs):
+            if i % (1000 // _get_batch_size(data)) == 0:
+                logging.debug(f"{i * _get_batch_size(data)} calibration inputs done")
+            m(*data)
+    else:
+        m(*calibration_inputs)
+    m = convert_pt2e(m)
+    logging.debug(f"---> Quantized model: {m}")
+    return m
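The model registry and the quantization helper above can also be driven directly from Python. A sketch, assuming the module is importable as `executorch.examples.nxp.aot_neutron_compile` from an ExecuTorch checkout:

```python
# Sketch: feeding a model's calibration iterator into post_training_quantize.
import itertools

import torch

from executorch.examples.nxp.aot_neutron_compile import (  # import path assumed
    models,
    post_training_quantize,
)

m = models["cifar10"]()
model = m.get_eager_model().eval()
example_inputs = m.get_example_inputs()

aten_module = torch.export.export_for_training(
    model, example_inputs, strict=True
).module()

# Any iterator of input tuples works; post_training_quantize assumes it is
# finite, and islice caps calibration at 100 batches here.
calibration = itertools.islice(m.get_calibration_inputs(64), 100)
quantized_module = post_training_quantize(aten_module, calibration)
```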
+ """ + # Based on executorch.examples.arm.aot_amr_compiler.quantize + logging.info("Quantizing model") + logging.debug(f"---> Original model: {model}") + quantizer = NeutronQuantizer() + + m = prepare_pt2e(model, quantizer) + # Calibration: + logging.debug("Calibrating model") + + def _get_batch_size(data): + return data[0].shape[0] + + if not isinstance( + calibration_inputs, tuple + ): # Assumption that calibration_inputs is finite. + for i, data in enumerate(calibration_inputs): + if i % (1000 // _get_batch_size(data)) == 0: + logging.debug(f"{i * _get_batch_size(data)} calibration inputs done") + m(*data) + else: + m(*calibration_inputs) + m = convert_pt2e(m) + logging.debug(f"---> Quantized model: {m}") + return m + + +if __name__ == "__main__": # noqa C901 + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model_name", + required=True, + help=f"Provide model name. Valid ones: {set(models.keys())}", + ) + parser.add_argument( + "-d", + "--delegate", + action="store_true", + required=False, + default=False, + help="Flag for producing eIQ NeutronBackend delegated model", + ) + parser.add_argument( + "--target", + required=False, + default="imxrt700", + help="Platform for running the delegated model", + ) + parser.add_argument( + "-c", + "--neutron_converter_flavor", + required=False, + default="SDK_25_03", + help="Flavor of installed neutron-converter module. Neutron-converter module named " + "'neutron_converter_SDK_24_12' has flavor 'SDK_24_12'.", + ) + parser.add_argument( + "-q", + "--quantize", + action="store_true", + required=False, + default=False, + help="Produce a quantized model", + ) + parser.add_argument( + "-s", + "--so_library", + required=False, + default=None, + help="Path to custome kernel library", + ) + parser.add_argument( + "--debug", action="store_true", help="Set the logging level to debug." + ) + parser.add_argument( + "-t", + "--test", + action="store_true", + required=False, + default=False, + help="Test the selected model and print the accuracy between 0 and 1.", + ) + parser.add_argument( + "--operators_not_to_delegate", + required=False, + default=[], + type=str, + nargs="*", + help="List of operators not to delegate. E.g., --operators_not_to_delegate aten::convolution aten::mm", + ) + + args = parser.parse_args() + + if args.debug: + logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True) + + # 1. pick model from one of the supported lists + model, example_inputs, calibration_inputs = get_model_and_inputs_from_name( + args.model_name + ) + model = model.eval() + + # 2. Export the model to ATEN + exported_program = torch.export.export_for_training( + model, example_inputs, strict=True + ) + + module = exported_program.module() + + # 4. Quantize if required + if args.quantize: + if calibration_inputs is None: + logging.warning( + "No calibration inputs available, using the example inputs instead" + ) + calibration_inputs = example_inputs + module = post_training_quantize(module, calibration_inputs) + + if args.so_library is not None: + logging.debug(f"Loading libraries: {args.so_library} and {args.portable_lib}") + torch.ops.load_library(args.so_library) + + if args.test: + match args.model_name: + case "cifar10": + accuracy = test_cifarnet_model(module) + + case _: + raise NotImplementedError( + f"Testing of model `{args.model_name}` is not yet supported." + ) + + quantized_str = "quantized " if args.quantize else "" + print(f"\nAccuracy of the {quantized_str}`{args.model_name}`: {accuracy}\n") + + # 5. 
diff --git a/examples/nxp/experimental/cifar_net/cifar_net.py b/examples/nxp/experimental/cifar_net/cifar_net.py
new file mode 100644
index 00000000000..1378d00cf12
--- /dev/null
+++ b/examples/nxp/experimental/cifar_net/cifar_net.py
@@ -0,0 +1,262 @@
+# Copyright 2024 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import itertools
+import logging
+import os.path
+from typing import Iterator, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+
+from executorch import exir
+from executorch.examples.models import model_base
+from torchvision import transforms
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class CifarNet(model_base.EagerModelBase):
+
+    def __init__(self, batch_size: int = 1, pth_file: str | None = None):
+        self.batch_size = batch_size
+        self.pth_file = pth_file or os.path.join(
+            os.path.dirname(__file__), "cifar_net.pth"
+        )
+
+    def get_eager_model(self) -> torch.nn.Module:
+        return get_model(self.batch_size, state_dict_file=self.pth_file)
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor]:
+        tl = get_test_loader()
+        ds, _ = tl.dataset[
+            0
+        ]  # The dataset returns the data and the class; we need just the data.
+        return (ds.unsqueeze(0),)
+
+    def get_calibration_inputs(
+        self, batch_size: int = 1
+    ) -> Iterator[Tuple[torch.Tensor]]:
+        tl = get_test_loader(batch_size)
+
+        def _get_first(a, _):
+            return (a,)
+
+        return itertools.starmap(_get_first, iter(tl))
+
+
+class CifarNetModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(8, 32, 5)
+        self.conv2 = nn.Conv2d(32, 32, 5)
+        self.conv3 = nn.Conv2d(32, 64, 5)
+        self.pool1 = nn.MaxPool2d(2, 2)
+        self.pool2 = nn.MaxPool2d(1, 2)
+        self.fc = nn.Linear(1024, 10)
+        self.softmax = nn.Softmax(1)
+
+    def forward(self, x):
+
+        # The Neutron backend does not yet have passes for automated padding when the number of
+        # channels does not match the Neutron constraints (#channels == #MAC units), so the model
+        # is defined explicitly, tailored for the Neutron-C-64.
+        x = F.pad(x, (2, 2, 2, 2, 0, 5))
+        x = self.conv1(x)
+        x = self.pool1(x)
+
+        x = F.pad(x, (2, 2, 2, 2))
+        x = self.conv2(x)
+        x = self.pool1(x)
+
+        x = F.pad(x, (2, 2, 2, 2))
+        x = self.conv3(x)
+        x = self.pool2(x)
+
+        # The output of the previous MaxPool has shape [batch, 64, 4, 4] ([batch, 4, 4, 64] in
+        # Neutron IR). When running inference of the `FullyConnected`, Neutron IR will automatically
+        # collapse the channels and spatial dimensions and work with a tensor of shape [batch, 1024].
+        # Without the explicit reshape below, PyTorch would combine C and H with `batch` and leave
+        # the last dimension (W), resulting in a tensor of shape [batch * 256, 4], which cannot be
+        # multiplied with the weight matrix of shape [1024, 10].
+        x = torch.reshape(x, (-1, 1024))
+        x = self.fc(x)
+        x = self.softmax(x)
+
+        return x
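The shape bookkeeping in the `forward` comments above can be verified with a quick host-side check; a minimal sketch, assuming `CifarNetModel` from this module is in scope:

```python
# Sketch: sanity-checking the shapes traced through CifarNetModel.forward.
import torch

net = CifarNetModel().eval()
x = torch.rand(1, 3, 32, 32)  # one CIFAR-10 image, NCHW
with torch.no_grad():
    y = net(x)
assert y.shape == (1, 10)  # one softmax score per CIFAR-10 class
```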
+
+
+def get_train_loader(batch_size: int = 1):
+    """Get a loader for the training data."""
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    train_set = torchvision.datasets.CIFAR10(
+        root="./data", train=True, download=True, transform=transform
+    )
+    train_loader = torch.utils.data.DataLoader(
+        train_set, batch_size=batch_size, shuffle=True, num_workers=0
+    )
+
+    return train_loader
+
+
+def get_test_loader(batch_size: int = 1):
+    """Get a loader for the testing data."""
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    test_set = torchvision.datasets.CIFAR10(
+        root="./data", train=False, download=True, transform=transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_set, batch_size=batch_size, shuffle=False, num_workers=0
+    )
+
+    return test_loader
+
+
+def get_model(
+    batch_size: int = 1,
+    state_dict_file: str | None = None,
+    train: bool = False,
+    num_epochs: int = 1,
+) -> nn.Module:
+    """Create the CifarNet model.
+
+    :param batch_size: Batch size to use during training.
+    :param state_dict_file: Path to a `.pth` file. If provided and the file exists, weights are
+        loaded from it. After training, the weights are saved back to it.
+    :param train: Boolean indicating whether to train the model.
+    :param num_epochs: Number of epochs to use during training.
+    :return: The loaded/trained CifarNet model.
+    """
+    cifar_net = CifarNetModel()
+
+    if state_dict_file is not None and os.path.isfile(state_dict_file):
+        # Load the pre-trained weights.
+        logger.info(f"Using pre-trained weights from `{state_dict_file}`.")
+        cifar_net.load_state_dict(torch.load(state_dict_file, weights_only=True))
+
+    if train:
+        # Train the model.
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.SGD(cifar_net.parameters(), lr=0.0005, momentum=0.6)
+        train_loader = get_train_loader(batch_size)
+
+        for epoch in range(num_epochs):
+            running_loss = 0.0
+            for i, data in enumerate(train_loader, 0):
+                # Get the inputs; data is a list of [inputs, labels].
+                inputs, labels = data
+
+                # Zero the parameter gradients.
+                optimizer.zero_grad()
+
+                # Forward + backward + optimize.
+                outputs = cifar_net(inputs)
+                loss = criterion(outputs, labels)
+                loss.backward()
+                optimizer.step()
+
+                # Print statistics every 2000 mini-batches.
+                running_loss += loss.item()
+                if i % 2000 == 1999:
+                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                    running_loss = 0.0
+
+        logger.info("Finished training.")
+
+    if state_dict_file is not None and train:
+        logger.info(f"Saving the trained weights in `{state_dict_file}`.")
+        torch.save(cifar_net.state_dict(), state_dict_file)
+
+    return cifar_net
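Typical `get_model` usage, for reference (a sketch; with `train=False` and an existing `.pth` file this only loads weights, while `train=True` fine-tunes and writes them back):

```python
# Sketch: load pre-trained weights from cifar_net.pth if it exists; no training.
net = get_model(state_dict_file="cifar_net.pth", train=False)
net.eval()
```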
+
+
+def get_cifarnet_calibration_data(num_images: int = 100) -> tuple[torch.Tensor]:
+    """Return a tuple containing 1 tensor (for the 1 model input) with shape
+    [`num_images`, 3, 32, 32].
+    """
+    loader = iter(get_train_loader(1))  # The train loader shuffles the images.
+    images = [image for image, _ in itertools.islice(loader, num_images)]
+    tensor = torch.vstack(images)
+    return (tensor,)
+
+
+def test_cifarnet_model(cifar_net: nn.Module, batch_size: int = 1) -> float:
+    """Test the CifarNet model on the CIFAR-10 test dataset and return the accuracy.
+
+    This function may at some point in the future be integrated into the `CifarNet` class.
+
+    :param cifar_net: The model to test with the CIFAR-10 test dataset.
+    :return: The accuracy of the model (between 0 and 1).
+    """
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for data in get_test_loader(batch_size):
+            images, labels = data
+            outputs = cifar_net(images)
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += torch.sum(predicted == labels).item()
+
+    return correct / total
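Combining the helpers above gives a quick accuracy check (a sketch; `get_test_loader` downloads CIFAR-10 into `./data` on first use):

```python
# Sketch: end-to-end evaluation with the helpers defined in this module.
net = get_model(state_dict_file="cifar_net.pth")
accuracy = test_cifarnet_model(net, batch_size=32)
print(f"Top-1 accuracy: {accuracy:.3f}")
```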
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--pte-file",
+        required=False,
+        help="Name of a `.pte` file to save the trained model in.",
+    )
+    parser.add_argument(
+        "--pth-file",
+        required=False,
+        type=str,
+        help="Name of a `.pth` file to save the trained weights in. If it already exists, the model "
+        "will be initialized with these weights.",
+    )
+    parser.add_argument(
+        "--train", required=False, action="store_true", help="Train the model."
+    )
+    parser.add_argument(
+        "--test", required=False, action="store_true", help="Test the trained model."
+    )
+    parser.add_argument("-b", "--batch-size", required=False, type=int, default=1)
+    parser.add_argument("-e", "--num-epochs", required=False, type=int, default=1)
+    args = parser.parse_args()
+
+    cifar_net = get_model(
+        state_dict_file=args.pth_file,
+        train=args.train,
+        batch_size=args.batch_size,
+        num_epochs=args.num_epochs,
+    )
+
+    if args.test:
+        logger.info("Running tests.")
+        accuracy = test_cifarnet_model(cifar_net, args.batch_size)
+        logger.info(f"Accuracy of the network on the 10000 test images: {accuracy}")
+
+    if args.pte_file is not None:
+        tracing_inputs = (torch.rand(args.batch_size, 3, 32, 32),)
+        aten_dialect_program = torch.export.export(cifar_net, tracing_inputs)
+        edge_dialect_program: exir.EdgeProgramManager = exir.to_edge(
+            aten_dialect_program
+        )
+        executorch_program = edge_dialect_program.to_executorch()
+
+        with open(args.pte_file, "wb") as file:
+            logger.info(f"Saving the trained model as `{args.pte_file}`.")
+            file.write(executorch_program.buffer)
diff --git a/examples/nxp/run_aot_example.sh b/examples/nxp/run_aot_example.sh
new file mode 100755
index 00000000000..1710490f6d7
--- /dev/null
+++ b/examples/nxp/run_aot_example.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+set -eux
+
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+EXECUTORCH_DIR=$(dirname "$(dirname "$SCRIPT_DIR")")
+
+cd "$EXECUTORCH_DIR"
+
+# Run the AoT example.
+python -m examples.nxp.aot_neutron_compile --quantize \
+    --delegate --neutron_converter_flavor SDK_25_03 -m cifar10
+# Verify the output file exists.
+test -f cifar10_nxp_delegate.pte
diff --git a/examples/nxp/setup.sh b/examples/nxp/setup.sh
old mode 100644
new mode 100755
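As a follow-up to the existence check in `run_aot_example.sh`, a `.pte` built without `--delegate` can also be executed on the host through the portable pybindings that `aot_neutron_compile.py` already imports (a sketch; a Neutron-delegated `.pte` only loads where the Neutron backend runtime is registered, so this applies to the CPU-only build):

```python
# Sketch: host-side smoke test for a .pte built WITHOUT --delegate.
# Importing executorch.kernels.quantized registers the quantized kernels
# the quantized model needs.
import torch
import executorch.kernels.quantized  # noqa: F401
from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("cifar10.pte")  # name produced without --delegate
(scores,) = module.forward((torch.rand(1, 3, 32, 32),))
print(scores.shape)  # expected: torch.Size([1, 10])
```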