|
| 1 | +# Copyright 2024 NXP |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import argparse |
| 7 | +import itertools |
| 8 | +import logging |
| 9 | +import os.path |
| 10 | +from typing import Iterator, Tuple |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.nn as nn |
| 14 | +import torch.nn.functional as F |
| 15 | +import torch.optim as optim |
| 16 | +import torchvision |
| 17 | + |
| 18 | +from executorch import exir |
| 19 | +from executorch.examples.models import model_base |
| 20 | +from torchvision import transforms |
| 21 | + |
| 22 | +logger = logging.getLogger(__name__) |
| 23 | +logger.setLevel(logging.INFO) |
| 24 | + |
| 25 | + |
| 26 | +class CifarNet(model_base.EagerModelBase): |
| 27 | + |
| 28 | + def __init__(self, batch_size: int = 1, pth_file: str | None = None): |
| 29 | + self.batch_size = batch_size |
| 30 | + self.pth_file = pth_file or os.path.join( |
| 31 | + os.path.dirname(__file__), "cifar_net.pth" |
| 32 | + ) |
| 33 | + |
| 34 | + def get_eager_model(self) -> torch.nn.Module: |
| 35 | + return get_model(self.batch_size, state_dict_file=self.pth_file) |
| 36 | + |
| 37 | + def get_example_inputs(self) -> Tuple[torch.Tensor]: |
| 38 | + tl = get_test_loader() |
| 39 | + ds, _ = tl.dataset[ |
| 40 | + 0 |
| 41 | + ] # Dataset returns the data and the class. We need just the data. |
| 42 | + return (ds.unsqueeze(0),) |
| 43 | + |
| 44 | + def get_calibration_inputs( |
| 45 | + self, batch_size: int = 1 |
| 46 | + ) -> Iterator[Tuple[torch.Tensor]]: |
| 47 | + tl = get_test_loader(batch_size) |
| 48 | + |
| 49 | + def _get_first(a, _): |
| 50 | + return (a,) |
| 51 | + |
| 52 | + return itertools.starmap(_get_first, iter(tl)) |
| 53 | + |
| 54 | + |
| 55 | +class CifarNetModel(nn.Module): |
| 56 | + |
| 57 | + def __init__(self): |
| 58 | + super().__init__() |
| 59 | + |
| 60 | + self.conv1 = nn.Conv2d(8, 32, 5) |
| 61 | + self.conv2 = nn.Conv2d(32, 32, 5) |
| 62 | + self.conv3 = nn.Conv2d(32, 64, 5) |
| 63 | + self.pool1 = nn.MaxPool2d(2, 2) |
| 64 | + self.pool2 = nn.MaxPool2d(1, 2) |
| 65 | + self.fc = nn.Linear(1024, 10) |
| 66 | + self.softmax = nn.Softmax(1) |
| 67 | + |
| 68 | + def forward(self, x): |
| 69 | + |
| 70 | + # Neutron Backend does not yet have passses for automated padding if number of channels does not |
| 71 | + # fit to Neutron constrains (#channels == #MAC units). So define the model explicitly tailored for Neutron-C-64. |
| 72 | + x = F.pad(x, (2, 2, 2, 2, 0, 5)) |
| 73 | + x = self.conv1(x) |
| 74 | + x = self.pool1(x) |
| 75 | + |
| 76 | + x = F.pad(x, (2, 2, 2, 2)) |
| 77 | + x = self.conv2(x) |
| 78 | + x = self.pool1(x) |
| 79 | + |
| 80 | + x = F.pad(x, (2, 2, 2, 2)) |
| 81 | + x = self.conv3(x) |
| 82 | + x = self.pool2(x) |
| 83 | + |
| 84 | + # The output of the previous MaxPool has shape [batch, 64, 4, 4] ([batch, 4, 4, 64] in Neutron IR). When running |
| 85 | + # inference of the `FullyConnected`, Neutron IR will automatically collapse the channels and spatial dimensions and |
| 86 | + # work with a tensor of shape [batch, 1024]. |
| 87 | + # PyTorch will combine the C and H with `batch`, and leave the last dimension (W). This will result in a tensor of |
| 88 | + # shape [batch * 256, 4]. This cannot be multiplied with the weight matrix of shape [1024, 10]. |
| 89 | + x = torch.reshape(x, (-1, 1024)) |
| 90 | + x = self.fc(x) |
| 91 | + x = self.softmax(x) |
| 92 | + |
| 93 | + return x |
| 94 | + |
| 95 | + |
| 96 | +def get_train_loader(batch_size: int = 1): |
| 97 | + """Get loader for training data.""" |
| 98 | + transform = transforms.Compose( |
| 99 | + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] |
| 100 | + ) |
| 101 | + |
| 102 | + train_set = torchvision.datasets.CIFAR10( |
| 103 | + root="./data", train=True, download=True, transform=transform |
| 104 | + ) |
| 105 | + train_loader = torch.utils.data.DataLoader( |
| 106 | + train_set, batch_size=batch_size, shuffle=True, num_workers=0 |
| 107 | + ) |
| 108 | + |
| 109 | + return train_loader |
| 110 | + |
| 111 | + |
| 112 | +def get_test_loader(batch_size: int = 1): |
| 113 | + """Get loader for testing data.""" |
| 114 | + transform = transforms.Compose( |
| 115 | + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] |
| 116 | + ) |
| 117 | + |
| 118 | + test_set = torchvision.datasets.CIFAR10( |
| 119 | + root="./data", train=False, download=True, transform=transform |
| 120 | + ) |
| 121 | + test_loader = torch.utils.data.DataLoader( |
| 122 | + test_set, batch_size=batch_size, shuffle=False, num_workers=0 |
| 123 | + ) |
| 124 | + |
| 125 | + return test_loader |
| 126 | + |
| 127 | + |
| 128 | +def get_model( |
| 129 | + batch_size: int = 1, |
| 130 | + state_dict_file: str | None = None, |
| 131 | + train: bool = False, |
| 132 | + num_epochs: int = 1, |
| 133 | +) -> nn.Module: |
| 134 | + """Create the CifarNet model. |
| 135 | +
|
| 136 | + :param batch_size: Batch size to use during training. |
| 137 | + :param state_dict_file: `.pth` file. If provided and the file exists, weights will be loaded from it. Also after |
| 138 | + training, the weights will be stored in the file. |
| 139 | + :param train: Boolean indicating whether to train the model. |
| 140 | + :param num_epochs: Number of epochs to use during training. |
| 141 | + :return: The loaded/trained CifarNet model. |
| 142 | + """ |
| 143 | + cifar_net = CifarNetModel() |
| 144 | + |
| 145 | + if state_dict_file is not None and os.path.isfile(state_dict_file): |
| 146 | + # Load the pre-trained weights. |
| 147 | + logger.info(f"Using pre-trained weights from `{state_dict_file}`.") |
| 148 | + cifar_net.load_state_dict(torch.load(state_dict_file, weights_only=True)) |
| 149 | + |
| 150 | + if train: |
| 151 | + # Train the model. |
| 152 | + criterion = nn.CrossEntropyLoss() |
| 153 | + optimizer = optim.SGD(cifar_net.parameters(), lr=0.0005, momentum=0.6) |
| 154 | + train_loader = get_train_loader(batch_size) |
| 155 | + |
| 156 | + for epoch in range(num_epochs): |
| 157 | + running_loss = 0.0 |
| 158 | + for i, data in enumerate(train_loader, 0): |
| 159 | + # get the inputs; data is a list of [inputs, labels] |
| 160 | + inputs, labels = data |
| 161 | + |
| 162 | + # zero the parameter gradients |
| 163 | + optimizer.zero_grad() |
| 164 | + |
| 165 | + # forward + backward + optimize |
| 166 | + outputs = cifar_net(inputs) |
| 167 | + loss = criterion(outputs, labels) |
| 168 | + loss.backward() |
| 169 | + optimizer.step() |
| 170 | + |
| 171 | + # print statistics |
| 172 | + running_loss += loss.item() |
| 173 | + if i % 2000 == 1999: # print every 2000 mini-batches |
| 174 | + print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") |
| 175 | + running_loss = 0.0 |
| 176 | + |
| 177 | + logger.info("Finished training.") |
| 178 | + if state_dict_file is not None and train: |
| 179 | + logger.info(f"Saving the trained weights in `{state_dict_file}`.") |
| 180 | + torch.save(cifar_net.state_dict(), state_dict_file) |
| 181 | + |
| 182 | + return cifar_net |
| 183 | + |
| 184 | + |
| 185 | +def get_cifarnet_calibration_data(num_images: int = 100) -> tuple[torch.Tensor]: |
| 186 | + """Return a tuple containing 1 tensor (for the 1 model input) and the tensor will have shape |
| 187 | + [`num_images`, 3, 32, 32]. |
| 188 | + """ |
| 189 | + loader = iter(get_train_loader(1)) # The train loader shuffles the images. |
| 190 | + images = [image for image, _ in itertools.islice(loader, num_images)] |
| 191 | + tensor = torch.vstack(images) |
| 192 | + return (tensor,) |
| 193 | + |
| 194 | + |
| 195 | +def test_cifarnet_model(cifar_net: nn.Module, batch_size: int = 1) -> float: |
| 196 | + """Test the CifarNet model on the CifarNet10 testing dataset and return the accuracy. |
| 197 | +
|
| 198 | + This function may at some point in the future be integrated into the `CifarNet` class. |
| 199 | +
|
| 200 | + :param cifar_net: The model to test with the CifarNet10 testing dataset. |
| 201 | + :return: The accuracy of the model (between 0 and 1). |
| 202 | + """ |
| 203 | + correct = 0 |
| 204 | + total = 0 |
| 205 | + with torch.no_grad(): |
| 206 | + for data in get_test_loader(batch_size): |
| 207 | + images, labels = data |
| 208 | + outputs = cifar_net(images) |
| 209 | + _, predicted = torch.max(outputs.data, 1) |
| 210 | + total += labels.size(0) |
| 211 | + correct += torch.sum(predicted == labels).item() |
| 212 | + |
| 213 | + return correct / total |
| 214 | + |
| 215 | + |
| 216 | +if __name__ == "__main__": |
| 217 | + parser = argparse.ArgumentParser() |
| 218 | + parser.add_argument( |
| 219 | + "--pte-file", |
| 220 | + required=False, |
| 221 | + help="Name of a `.pte` file to save the trained model in.", |
| 222 | + ) |
| 223 | + parser.add_argument( |
| 224 | + "--pth-file", |
| 225 | + required=False, |
| 226 | + type=str, |
| 227 | + help="Name of a `.pth` file to save the trained weights in. If it already exists, the model " |
| 228 | + "will be initialized with these weights.", |
| 229 | + ) |
| 230 | + parser.add_argument( |
| 231 | + "--train", required=False, action="store_true", help="Train the model." |
| 232 | + ) |
| 233 | + parser.add_argument( |
| 234 | + "--test", required=False, action="store_true", help="Test the trained model." |
| 235 | + ) |
| 236 | + parser.add_argument("-b", "--batch-size", required=False, type=int, default=1) |
| 237 | + parser.add_argument("-e", "--num-epochs", required=False, type=int, default=1) |
| 238 | + args = parser.parse_args() |
| 239 | + |
| 240 | + cifar_net = get_model( |
| 241 | + state_dict_file=args.pth_file, |
| 242 | + train=args.train, |
| 243 | + batch_size=args.batch_size, |
| 244 | + num_epochs=args.num_epochs, |
| 245 | + ) |
| 246 | + |
| 247 | + if args.test: |
| 248 | + logger.info("Running tests.") |
| 249 | + accuracy = test_cifarnet_model(cifar_net, args.batch_size) |
| 250 | + logger.info(f"Accuracy of the network on the 10000 test images: {accuracy}") |
| 251 | + |
| 252 | + if args.pte_file is not None: |
| 253 | + tracing_inputs = (torch.rand(args.batch_size, 3, 32, 32),) |
| 254 | + aten_dialect_program = torch.export.export(cifar_net, tracing_inputs) |
| 255 | + edge_dialect_program: exir.EdgeProgramManager = exir.to_edge( |
| 256 | + aten_dialect_program |
| 257 | + ) |
| 258 | + executorch_program = edge_dialect_program.to_executorch() |
| 259 | + |
| 260 | + with open(args.pte_file, "wb") as file: |
| 261 | + logger.info(f"Saving the trained model as `{args.pte_file}`.") |
| 262 | + file.write(executorch_program.buffer) |
0 commit comments