diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index bdab21af3da..1c4a13ece25 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -800,3 +800,6 @@ jobs:
 
         # Run pytest
         PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh
+
+        # Run the AoT example
+        PYTHON_EXECUTABLE=python bash examples/nxp/run_aot_example.sh
diff --git a/examples/nxp/README.md b/examples/nxp/README.md
new file mode 100644
index 00000000000..66ca0785b4c
--- /dev/null
+++ b/examples/nxp/README.md
@@ -0,0 +1,20 @@
+# PyTorch Model Delegation to Neutron Backend
+
+In this guide we will show how to use the ExecuTorch AoT flow to convert a PyTorch model to the ExecuTorch format and delegate the model computation to the eIQ Neutron NPU using the eIQ Neutron Backend.
+
+First we will start with an example script that converts the model. This example shows the CifarNet model preparation; it is the same model that is part of the `example_cifarnet`.
+
+The steps are expected to be executed from the executorch root folder.
+1. Run the `setup.sh` script to install the neutron-converter:
+```commandline
+$ examples/nxp/setup.sh
+```
+
+2. Now run the `aot_neutron_compile.py` example with the `cifar10` model:
+```commandline
+$ python -m examples.nxp.aot_neutron_compile --quantize \
+    --delegate --neutron_converter_flavor SDK_25_03 -m cifar10
+```
+
+3. It will generate the `cifar10_nxp_delegate.pte` file, which can be used with the MCUXpresso SDK `cifarnet_example` project, presented [here](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/topics/example_applications.html#how-to-build-and-run-executorch-cifarnet-example).
+To get the MCUXpresso SDK, follow this [guide](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/topics/getting_mcuxpresso.html) and use MCUXpresso SDK v25.03.00.
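+
+For reference, the core of the AoT flow performed by `aot_neutron_compile.py` is sketched below, using the same APIs the script itself uses (simplified; the script additionally handles calibration data, accuracy testing and error reporting). Here `model` and `example_inputs` are placeholders for an eager `torch.nn.Module` and a tuple of example input tensors:
+```python
+import torch
+
+from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
+from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
+from executorch.extension.export_util import save_pte_program
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+model = ...           # an eager torch.nn.Module, e.g. the CifarNet model from this example
+example_inputs = ...  # tuple of example input tensors
+
+# Export to the ATen dialect and quantize with the Neutron quantizer.
+module = torch.export.export_for_training(model, example_inputs, strict=True).module()
+prepared = prepare_pt2e(module, NeutronQuantizer())
+prepared(*example_inputs)  # calibration
+quantized = convert_pt2e(prepared)
+
+# Lower to the Neutron backend and serialize to a .pte file.
+partitioner = NeutronPartitioner(generate_neutron_compile_spec("imxrt700", "SDK_25_03"))
+edge_program = to_edge_transform_and_lower(
+    torch.export.export(quantized, example_inputs, strict=True),
+    partitioner=[partitioner],
+    compile_config=EdgeCompileConfig(_check_ir_validity=False),
+)
+save_pte_program(edge_program.to_executorch(), "cifar10_nxp_delegate")
+```
+
+The script also accepts `--operators_not_to_delegate` to keep selected operators out of the Neutron delegate, e.g. `--operators_not_to_delegate aten::convolution aten::mm`.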
diff --git a/examples/nxp/aot_neutron_compile.py b/examples/nxp/aot_neutron_compile.py
new file mode 100644
index 00000000000..d8e4d324de2
--- /dev/null
+++ b/examples/nxp/aot_neutron_compile.py
@@ -0,0 +1,295 @@
+# Copyright 2024-2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script to compile the model for the NXP Neutron NPU
+
+import argparse
+import io
+import logging
+from collections import defaultdict
+from typing import Iterator
+
+import executorch.extension.pybindings.portable_lib  # noqa F401
+import executorch.kernels.quantized  # noqa F401
+
+import torch
+
+from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
+from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.examples.models import MODEL_NAME_TO_MODEL
+from executorch.examples.models.model_factory import EagerModelFactory
+
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge_transform_and_lower,
+)
+from executorch.extension.export_util import save_pte_program
+
+from torch.export import export
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+def print_ops_in_edge_program(edge_program):
+    """Find all ops used in the `edge_program` and print them out along with their occurrence counts."""
+
+    ops_and_counts = defaultdict(
+        lambda: 0
+    )  # Mapping ops to the number of times they are used.
+    for node in edge_program.graph.nodes:
+        if "call" not in node.op:
+            continue  # `placeholder` or `output`. (not an operator)
+
+        if hasattr(node.target, "_schema"):
+            # Regular op.
+            # noinspection PyProtectedMember
+            op = node.target._schema.schema.name
+        else:
+            # Builtin function.
+            op = str(node.target)
+
+        ops_and_counts[op] += 1
+
+    # Sort the ops based on how many times they are used in the model.
+    ops_and_counts = sorted(ops_and_counts.items(), key=lambda x: x[1], reverse=True)
+
+    # Print the ops and use counts.
+    for op, count in ops_and_counts:
+        print(f"{op: <50} {count}x")
+
+
+def get_model_and_inputs_from_name(model_name: str):
+    """Given the name of an example pytorch model, return it, example inputs and calibration inputs (can be None)
+
+    Raises RuntimeError if there is no example model corresponding to the given name.
+    """
+
+    calibration_inputs = None
+    # Case 1: Model is defined in this file
+    if model_name in models.keys():
+        m = models[model_name]()
+        model = m.get_eager_model()
+        example_inputs = m.get_example_inputs()
+        calibration_inputs = m.get_calibration_inputs(64)
+    # Case 2: Model is defined in executorch/examples/models/
+    elif model_name in MODEL_NAME_TO_MODEL.keys():
+        logging.warning(
+            "Using a model from examples/models not all of these are currently supported"
+        )
+        model, example_inputs, _ = EagerModelFactory.create_model(
+            *MODEL_NAME_TO_MODEL[model_name]
+        )
+    else:
+        raise RuntimeError(
+            f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
+        )
+
+    return model, example_inputs, calibration_inputs
+
+
+models = {
+    "cifar10": CifarNet,
+}
+
+
+def post_training_quantize(
+    model, calibration_inputs: tuple[torch.Tensor] | Iterator[tuple[torch.Tensor]]
+):
+    """Quantize the provided model.
+
+    :param model: Aten model to quantize.
+    :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
+                                input. Or an iterator over such tuples.
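+
+    A minimal usage sketch (mirroring the `__main__` flow below):
+        module = torch.export.export_for_training(model, example_inputs, strict=True).module()
+        quantized_module = post_training_quantize(module, example_inputs)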
+    """
+    # Based on executorch.examples.arm.aot_arm_compiler.quantize
+    logging.info("Quantizing model")
+    logging.debug(f"---> Original model: {model}")
+    quantizer = NeutronQuantizer()
+
+    m = prepare_pt2e(model, quantizer)
+    # Calibration:
+    logging.debug("Calibrating model")
+
+    def _get_batch_size(data):
+        return data[0].shape[0]
+
+    if not isinstance(
+        calibration_inputs, tuple
+    ):  # Assumption that calibration_inputs is finite.
+        for i, data in enumerate(calibration_inputs):
+            if i % (1000 // _get_batch_size(data)) == 0:
+                logging.debug(f"{i * _get_batch_size(data)} calibration inputs done")
+            m(*data)
+    else:
+        m(*calibration_inputs)
+    m = convert_pt2e(m)
+    logging.debug(f"---> Quantized model: {m}")
+    return m
+
+
+if __name__ == "__main__":  # noqa C901
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--model_name",
+        required=True,
+        help=f"Provide model name. Valid ones: {set(models.keys())}",
+    )
+    parser.add_argument(
+        "-d",
+        "--delegate",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Flag for producing eIQ NeutronBackend delegated model",
+    )
+    parser.add_argument(
+        "--target",
+        required=False,
+        default="imxrt700",
+        help="Platform for running the delegated model",
+    )
+    parser.add_argument(
+        "-c",
+        "--neutron_converter_flavor",
+        required=False,
+        default="SDK_25_03",
+        help="Flavor of installed neutron-converter module. Neutron-converter module named "
+        "'neutron_converter_SDK_24_12' has flavor 'SDK_24_12'.",
+    )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Produce a quantized model",
+    )
+    parser.add_argument(
+        "-s",
+        "--so_library",
+        required=False,
+        default=None,
+        help="Path to custome kernel library",
+    )
+    parser.add_argument(
+        "--debug", action="store_true", help="Set the logging level to debug."
+    )
+    parser.add_argument(
+        "-t",
+        "--test",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Test the selected model and print the accuracy between 0 and 1.",
+    )
+    parser.add_argument(
+        "--operators_not_to_delegate",
+        required=False,
+        default=[],
+        type=str,
+        nargs="*",
+        help="List of operators not to delegate. E.g., --operators_not_to_delegate aten::convolution aten::mm",
+    )
+
+    args = parser.parse_args()
+
+    if args.debug:
+        logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
+
+    # 1. Pick a model from one of the supported lists
+    model, example_inputs, calibration_inputs = get_model_and_inputs_from_name(
+        args.model_name
+    )
+    model = model.eval()
+
+    # 2. Export the model to ATEN
+    exported_program = torch.export.export_for_training(
+        model, example_inputs, strict=True
+    )
+
+    module = exported_program.module()
+
+    # 3. Quantize if required
+    if args.quantize:
+        if calibration_inputs is None:
+            logging.warning(
+                "No calibration inputs available, using the example inputs instead"
+            )
+            calibration_inputs = example_inputs
+        module = post_training_quantize(module, calibration_inputs)
+
+    if args.so_library is not None:
+        logging.debug(f"Loading libraries: {args.so_library} and {args.portable_lib}")
+        torch.ops.load_library(args.so_library)
+
+    if args.test:
+        match args.model_name:
+            case "cifar10":
+                accuracy = test_cifarnet_model(module)
+
+            case _:
+                raise NotImplementedError(
+                    f"Testing of model `{args.model_name}` is not yet supported."
+                )
+
+        quantized_str = "quantized " if args.quantize else ""
+        print(f"\nAccuracy of the {quantized_str}`{args.model_name}`: {accuracy}\n")
+
+    # 4. Export to edge program
+    partitioner_list = []
+    if args.delegate is True:
+        partitioner_list = [
+            NeutronPartitioner(
+                generate_neutron_compile_spec(
+                    args.target,
+                    args.neutron_converter_flavor,
+                    operators_not_to_delegate=args.operators_not_to_delegate,
+                )
+            )
+        ]
+
+    edge_program = to_edge_transform_and_lower(
+        export(module, example_inputs, strict=True),
+        partitioner=partitioner_list,
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=False,
+        ),
+    )
+    logging.debug(f"Exported graph:\n{edge_program.exported_program().graph}")
+
+    # 5. Export to ExecuTorch program
+    try:
+        exec_prog = edge_program.to_executorch(
+            config=ExecutorchBackendConfig(extract_delegate_segments=False)
+        )
+    except RuntimeError as e:
+        if "Missing out variants" in str(e.args[0]):
+            raise RuntimeError(
+                e.args[0]
+                + ".\nThis likely due to an external so library not being loaded. Supply a path to it with the "
+                "--portable_lib flag."
+            ).with_traceback(e.__traceback__) from None
+        else:
+            raise e
+
+    def executorch_program_to_str(ep, verbose=False):
+        f = io.StringIO()
+        ep.dump_executorch_program(out=f, verbose=verbose)
+        return f.getvalue()
+
+    logging.debug(f"Executorch program:\n{executorch_program_to_str(exec_prog)}")
+
+    # 6. Serialize to *.pte
+    model_name = f"{args.model_name}" + (
+        "_nxp_delegate" if args.delegate is True else ""
+    )
+    save_pte_program(exec_prog, model_name)
diff --git a/examples/nxp/experimental/cifar_net/cifar_net.pth b/examples/nxp/experimental/cifar_net/cifar_net.pth
new file mode 100644
index 00000000000..6dc4efde21d
Binary files /dev/null and b/examples/nxp/experimental/cifar_net/cifar_net.pth differ
diff --git a/examples/nxp/experimental/cifar_net/cifar_net.py b/examples/nxp/experimental/cifar_net/cifar_net.py
new file mode 100644
index 00000000000..1378d00cf12
--- /dev/null
+++ b/examples/nxp/experimental/cifar_net/cifar_net.py
@@ -0,0 +1,262 @@
+# Copyright 2024 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import itertools
+import logging
+import os.path
+from typing import Iterator, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+
+from executorch import exir
+from executorch.examples.models import model_base
+from torchvision import transforms
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class CifarNet(model_base.EagerModelBase):
+
+    def __init__(self, batch_size: int = 1, pth_file: str | None = None):
+        self.batch_size = batch_size
+        self.pth_file = pth_file or os.path.join(
+            os.path.dirname(__file__), "cifar_net.pth"
+        )
+
+    def get_eager_model(self) -> torch.nn.Module:
+        return get_model(self.batch_size, state_dict_file=self.pth_file)
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor]:
+        tl = get_test_loader()
+        ds, _ = tl.dataset[
+            0
+        ]  # Dataset returns the data and the class. We need just the data.
+        return (ds.unsqueeze(0),)
+
+    def get_calibration_inputs(
+        self, batch_size: int = 1
+    ) -> Iterator[Tuple[torch.Tensor]]:
+        tl = get_test_loader(batch_size)
+
+        def _get_first(a, _):
+            return (a,)
+
+        return itertools.starmap(_get_first, iter(tl))
+
+
+class CifarNetModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(8, 32, 5)
+        self.conv2 = nn.Conv2d(32, 32, 5)
+        self.conv3 = nn.Conv2d(32, 64, 5)
+        self.pool1 = nn.MaxPool2d(2, 2)
+        self.pool2 = nn.MaxPool2d(1, 2)
+        self.fc = nn.Linear(1024, 10)
+        self.softmax = nn.Softmax(1)
+
+    def forward(self, x):
+
+        # The Neutron Backend does not yet have passes for automated padding when the number of channels does not
+        # fit the Neutron constraints (#channels == #MAC units), so the model is defined explicitly tailored for
+        # Neutron-C-64. The pad below extends the input from 3 to 8 channels and pads H and W by 2 on each side.
+        x = F.pad(x, (2, 2, 2, 2, 0, 5))
+        x = self.conv1(x)
+        x = self.pool1(x)
+
+        x = F.pad(x, (2, 2, 2, 2))
+        x = self.conv2(x)
+        x = self.pool1(x)
+
+        x = F.pad(x, (2, 2, 2, 2))
+        x = self.conv3(x)
+        x = self.pool2(x)
+
+        # The output of the previous MaxPool has shape [batch, 64, 4, 4] ([batch, 4, 4, 64] in Neutron IR). When running
+        #  the `FullyConnected`, Neutron IR automatically collapses the channel and spatial dimensions and works with a
+        #  tensor of shape [batch, 1024].
+        # PyTorch, however, would combine C and H with `batch` and leave the last dimension (W), resulting in a tensor of
+        #  shape [batch * 256, 4], which cannot be multiplied with the weight matrix of shape [1024, 10]. Hence the
+        #  explicit reshape below.
+        x = torch.reshape(x, (-1, 1024))
+        x = self.fc(x)
+        x = self.softmax(x)
+
+        return x
+
+
+def get_train_loader(batch_size: int = 1):
+    """Get loader for training data."""
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    train_set = torchvision.datasets.CIFAR10(
+        root="./data", train=True, download=True, transform=transform
+    )
+    train_loader = torch.utils.data.DataLoader(
+        train_set, batch_size=batch_size, shuffle=True, num_workers=0
+    )
+
+    return train_loader
+
+
+def get_test_loader(batch_size: int = 1):
+    """Get loader for testing data."""
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    test_set = torchvision.datasets.CIFAR10(
+        root="./data", train=False, download=True, transform=transform
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_set, batch_size=batch_size, shuffle=False, num_workers=0
+    )
+
+    return test_loader
+
+
+def get_model(
+    batch_size: int = 1,
+    state_dict_file: str | None = None,
+    train: bool = False,
+    num_epochs: int = 1,
+) -> nn.Module:
+    """Create the CifarNet model.
+
+    :param batch_size: Batch size to use during training.
+    :param state_dict_file: `.pth` file. If provided and the file exists, weights will be loaded from it. Also after
+                             training, the weights will be stored in the file.
+    :param train: Boolean indicating whether to train the model.
+    :param num_epochs: Number of epochs to use during training.
+    :return: The loaded/trained CifarNet model.
+    """
+    cifar_net = CifarNetModel()
+
+    if state_dict_file is not None and os.path.isfile(state_dict_file):
+        # Load the pre-trained weights.
+        logger.info(f"Using pre-trained weights from `{state_dict_file}`.")
+        cifar_net.load_state_dict(torch.load(state_dict_file, weights_only=True))
+
+    if train:
+        # Train the model.
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.SGD(cifar_net.parameters(), lr=0.0005, momentum=0.6)
+        train_loader = get_train_loader(batch_size)
+
+        for epoch in range(num_epochs):
+            running_loss = 0.0
+            for i, data in enumerate(train_loader, 0):
+                # get the inputs; data is a list of [inputs, labels]
+                inputs, labels = data
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward + backward + optimize
+                outputs = cifar_net(inputs)
+                loss = criterion(outputs, labels)
+                loss.backward()
+                optimizer.step()
+
+                # print statistics
+                running_loss += loss.item()
+                if i % 2000 == 1999:  # print every 2000 mini-batches
+                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                    running_loss = 0.0
+
+        logger.info("Finished training.")
+        if state_dict_file is not None and train:
+            logger.info(f"Saving the trained weights in `{state_dict_file}`.")
+            torch.save(cifar_net.state_dict(), state_dict_file)
+
+    return cifar_net
+
+
+def get_cifarnet_calibration_data(num_images: int = 100) -> tuple[torch.Tensor]:
+    """Return a tuple containing 1 tensor (for the 1 model input) and the tensor will have shape
+    [`num_images`, 3, 32, 32].
+    """
+    loader = iter(get_train_loader(1))  # The train loader shuffles the images.
+    images = [image for image, _ in itertools.islice(loader, num_images)]
+    tensor = torch.vstack(images)
+    return (tensor,)
+
+
+def test_cifarnet_model(cifar_net: nn.Module, batch_size: int = 1) -> float:
+    """Test the CifarNet model on the CifarNet10 testing dataset and return the accuracy.
+
+    This function may at some point in the future be integrated into the `CifarNet` class.
+
+    :param cifar_net: The model to test on the CIFAR-10 test dataset.
+    :return: The accuracy of the model (between 0 and 1).
+    """
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for data in get_test_loader(batch_size):
+            images, labels = data
+            outputs = cifar_net(images)
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += torch.sum(predicted == labels).item()
+
+    return correct / total
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--pte-file",
+        required=False,
+        help="Name of a `.pte` file to save the trained model in.",
+    )
+    parser.add_argument(
+        "--pth-file",
+        required=False,
+        type=str,
+        help="Name of a `.pth` file to save the trained weights in. If it already exists, the model "
+        "will be initialized with these weights.",
+    )
+    parser.add_argument(
+        "--train", required=False, action="store_true", help="Train the model."
+    )
+    parser.add_argument(
+        "--test", required=False, action="store_true", help="Test the trained model."
+    )
+    parser.add_argument("-b", "--batch-size", required=False, type=int, default=1)
+    parser.add_argument("-e", "--num-epochs", required=False, type=int, default=1)
+    args = parser.parse_args()
+
+    cifar_net = get_model(
+        state_dict_file=args.pth_file,
+        train=args.train,
+        batch_size=args.batch_size,
+        num_epochs=args.num_epochs,
+    )
+
+    if args.test:
+        logger.info("Running tests.")
+        accuracy = test_cifarnet_model(cifar_net, args.batch_size)
+        logger.info(f"Accuracy of the network on the 10000 test images: {accuracy}")
+
+    if args.pte_file is not None:
+        tracing_inputs = (torch.rand(args.batch_size, 3, 32, 32),)
+        aten_dialect_program = torch.export.export(cifar_net, tracing_inputs)
+        edge_dialect_program: exir.EdgeProgramManager = exir.to_edge(
+            aten_dialect_program
+        )
+        executorch_program = edge_dialect_program.to_executorch()
+
+        with open(args.pte_file, "wb") as file:
+            logger.info(f"Saving the trained model as `{args.pte_file}`.")
+            file.write(executorch_program.buffer)
diff --git a/examples/nxp/run_aot_example.sh b/examples/nxp/run_aot_example.sh
new file mode 100755
index 00000000000..1710490f6d7
--- /dev/null
+++ b/examples/nxp/run_aot_example.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+set -eux
+
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+EXECUTORCH_DIR=$(dirname "$(dirname "$SCRIPT_DIR")")
+
+cd "$EXECUTORCH_DIR"
+
+# Run the AoT example
+python -m examples.nxp.aot_neutron_compile --quantize \
+    --delegate --neutron_converter_flavor SDK_25_03 -m cifar10
+# Verify that the output file exists
+test -f cifar10_nxp_delegate.pte
diff --git a/examples/nxp/setup.sh b/examples/nxp/setup.sh
old mode 100644
new mode 100755