|
21 | 21 | "showInput": false
|
22 | 22 | },
|
23 | 23 | "source": [
|
24 |
| - "This tutorial demonstrates how to apply the [TracInCP](https://arxiv.org/pdf/2002.08484.pdf) algorithm for influential examples interpretability from the Captum library. TracInCP calculates the influence score of a given training example on a given test example, which roughly speaking, represents how much higher the loss for the given test example would be if the given training example were removed from the training dataset, and the model re-trained. This functionality can be leveraged towards the following 2 use cases:\n", |
25 |
| - "1. For a single test example, identifying its influential examples. These are the training examples with the most positive influence scores (proponents) and the training examples with the most negative influence scores (opponents). \n", |
26 |
| - "2. Identifying mis-labelled data, i.e. examples in the training data whose \"ground-truth\" label is actually incorrect. The influence score of a mis-labelled example on itself (i.e. self influence score) will tend to be high. Thus to find mis-labelled examples, one can examine examples in order of decreasing self influence score.\n", |
| 24 | + "This tutorial demonstrates how to apply the [TracInCP](https://arxiv.org/pdf/2002.08484.pdf) algorithm for influential examples interpretability from the Captum library. TracInCP calculates the influence score of a given training example on a given test example, which roughly speaking, represents how much higher the loss for the given test example would be if the given training example were removed from the training dataset, and the model re-trained. This functionality can be leveraged towards the following two use cases:\n", |
| 25 | + "1. For a single test example, identifying its influential examples. These are the training examples with the most positive influence scores (proponents) and the training examples with the most negative influence scores (opponents).\n", |
| 26 | + "2. Identifying mislabelled data, i.e. examples in the training data whose \"ground-truth\" label is actually incorrect. The influence score of a mislabelled example on itself (i.e. self-influence score) will tend to be high. Thus to find mislabelled examples, one can examine examples in order of decreasing self-influence score.\n", |
27 | 27 | "\n",
|
28 | 28 | "\n",
|
29 | 29 | "TracInCP can be used for any trained Pytorch model for which several model checkpoints are available.\n",
|
30 | 30 | "\n",
|
31 | 31 | " \n",
|
32 | 32 | " **Note:** Before running this tutorial, please do the following:\n",
|
33 | 33 | " - install Captum.\n",
|
34 |
| - " - install the torchvision, and matplotlib packages.\n", |
| 34 | + " - install the torchvision and matplotlib packages.\n", |
35 | 35 | " - install the [Annoy](https://github.com/spotify/annoy) Python module."
|
36 | 36 | ]
|
37 | 37 | },
|
|
46 | 46 | },
|
47 | 47 | "source": [
|
48 | 48 | "## Overview of different implementations of the TracInCP method\n",
|
49 |
| - "Currently, Captum offers 3 implementations, all of which implement the same API. More specifically, they define an `influence` method, which can be used in 2 different modes:\n", |
50 |
| - "- influence score mode: given a batch of test examples, calculates the influence score of every example in the training dataset on every test example.\n", |
51 |
| - "- top-k most influential mode: given a batch of test examples, calculates either the proponents or opponents of every test example, as well as their corresponding influence scores.\n", |
| 49 | + "Currently, Captum offers 3 implementations, all of which implement the same API. More specifically, they define an `influence` method, which can be used in two different modes:\n", |
| 50 | + "- influence score mode: given a batch of test examples, calculate the influence score of every example in the training dataset on every test example.\n", |
| 51 | + "- top-k most influential mode: given a batch of test examples, calculate either the proponents or opponents of every test example, as well as their corresponding influence scores.\n", |
52 | 52 | "\n",
|
53 |
| - "The 3 different implementations are defined in the following classes:\n", |
54 |
| - "- `TracInCP`: considers gradients in all specified layers when computing influence scores. Specifying many layers will slow the execution of all 3 modes.\n", |
| 53 | + "The three different implementations are defined in the following classes:\n", |
| 54 | + "- `TracInCP`: considers gradients in all specified layers when computing influence scores. Specifying many layers will slow the execution of all three modes.\n", |
55 | 55 | "- `TracInCPFast`: In Appendix F of the TracIn paper, they show that if considering only gradients in the last fully-connected layer when computing influence scores, the computation can be done more quickly than naively applying backprop to compute gradients, using a computational trick. `TracInCPFast` computes influence scores, considering only the last fully-connected layer, using that trick. `TracInCPFast` is useful if you want to reduce the time and memory usage, relative to `TracInCP`.\n",
|
56 | 56 | "- `TracInCPFastRandProj`: The previous two classes were not meant for \"interactive\" use, because each call to `influence` in influence score mode or top-k most influential mode takes time proportional to the training dataset size. On the other hand, `TracInCPFastRandProj` enables \"interactive\" use, i.e. constant-time calls to `influence` for those two modes. The price we pay is that in `TracInCPFastRandProj.__init__`, pre-processing is done to store embeddings related to each training example into a nearest-neighbors data structure. This pre-processing takes both time and memory proportional to training dataset size. Furthermore, random projections can be applied to reduce memory usage, at the cost of the influence scores used in those two modes to be only approximately correct. Like `TracInCPFast`, this class only considers gradients in the last fully-connected layer, and is useful if you want to reduce the time and memory usage, relative to `TracInCP`."
|
57 | 57 | ]
|
|
120 | 120 | },
|
121 | 121 | "source": [
|
122 | 122 | "#### Define `net`\n",
|
123 |
| - "We will use a relatively simple model from the following tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py" |
| 123 | + "We will use a relatively simple model from the following tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py." |
124 | 124 | ]
|
125 | 125 | },
|
126 | 126 | {
|
|
153 | 153 | " def __init__(self):\n",
|
154 | 154 | " super(Net, self).__init__()\n",
|
155 | 155 | " self.conv1 = nn.Conv2d(3, 6, 5)\n",
|
156 |
| - " self.pool1 = nn.MaxPool2d(2, 2)\n", |
157 |
| - " self.pool2 = nn.MaxPool2d(2, 2)\n", |
158 | 156 | " self.conv2 = nn.Conv2d(6, 16, 5)\n",
|
| 157 | + " self.pool1 = nn.MaxPool2d(2, 2)\n", |
| 158 | + " self.pool2 = nn.MaxPool2d(2, 2)\n", |
159 | 159 | " self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
|
160 | 160 | " self.fc2 = nn.Linear(120, 84)\n",
|
161 | 161 | " self.fc3 = nn.Linear(84, 10)\n",
|
|
280 | 280 | },
|
281 | 281 | "source": [
|
282 | 282 | "#### Define `test_dataset`\n",
|
283 |
| - "This will be the same as `correct_dataset`, so that it shares the same path and transform. The only difference is that that it uses the validation split" |
| 283 | + "This will be the same as `correct_dataset`, so that it shares the same path and transform. The only difference is that it uses the validation split." |
284 | 284 | ]
|
285 | 285 | },
|
286 | 286 | {
|
|
329 | 329 | "showInput": false
|
330 | 330 | },
|
331 | 331 | "source": [
|
332 |
| - "We first define a training function, which is copied from the above tutorial" |
| 332 | + "We first define a training function, which is copied from the above tutorial." |
333 | 333 | ]
|
334 | 334 | },
|
335 | 335 | {
|
|
558 | 558 | "showInput": false
|
559 | 559 | },
|
560 | 560 | "source": [
|
561 |
| - "We first load `net` with the last checkpoint so that the predictions we make in the next cell will be for the trained model. We save this last checkpoint as `correct_dataset_final_checkpoint`, because it turns out we will re-use this checkpoint later on." |
| 561 | + "We first load `net` with the last checkpoint so that the predictions we make in the next cell will be for the trained model. We save this last checkpoint as `correct_dataset_final_checkpoint`, because we will re-use this checkpoint later on." |
562 | 562 | ]
|
563 | 563 | },
|
564 | 564 | {
|
|
1512 | 1512 | "source": [
|
1513 | 1513 | "# Identifying mislabelled data\n",
|
1514 | 1514 | "Now, we will illustrate the ability of TracInCP to identify mislabelled data. As before, we need 3 components:\n",
|
1515 |
| - "- a Pytorch model. We will continue to use `net`\n", |
1516 |
| - "- a Pytorch `Dataset` used to train `net`, `mislabelled_dataset`.\n", |
| 1515 | + "- A Pytorch model. We will continue to use `net`.\n", |
| 1516 | + "- A Pytorch `Dataset` used to train `net`, `mislabelled_dataset`.\n", |
1517 | 1517 | "- A Pytorch `Dataset`, `test_dataset`. For this, we will use the original CIFAR-10 validation split. Unlike when using TracInCP to identify influential examples for certain test examples, we only use this for monitoring training; we are just identifying mislabelled examples in the training data. Also note that unlike `mislabelled_dataset`, `test_dataset` will not have mislabelled examples.\n",
|
1518 | 1518 | "- checkpoints from training `net` with `mislabelled_dataset`"
|
1519 | 1519 | ]
|
|
1561 | 1561 | "source": [
|
1562 | 1562 | "#### Define `mislabelled_dataset`\n",
|
1563 | 1563 | "\n",
|
1564 |
| - "We now define `mislabelled_dataset` by artificially introducing mis-labelled examples into `correct_dataset`. Using artificial data lets us know the ground-truth for whether examples really are mis-labelled, and thus do evaluation. We create `mislabelled_dataset` from `correct_dataset` using the following procedure: We initialize the Pytorch model, trained using `correct_dataset`, as `correct_dataset_net`. For 10% of the examples in `correct_dataset`, we use `correct_dataset_net` to predict the probability the example belongs to each class. We then change the label to the most probable label that is *incorrect*.\n", |
| 1564 | + "We now define `mislabelled_dataset` by artificially introducing mislabelled examples into `correct_dataset`. Using artificial data lets us know the ground-truth for whether examples really are mis-labelled, and thus do evaluation. We create `mislabelled_dataset` from `correct_dataset` using the following procedure: We initialize the Pytorch model, trained using `correct_dataset`, as `correct_dataset_net`. For 10% of the examples in `correct_dataset`, we use `correct_dataset_net` to predict the probability the example belongs to each class. We then change the label to the most probable label that is *incorrect*.\n", |
1565 | 1565 | "\n",
|
1566 | 1566 | "Note that to know the ground truth for which examples in `mislabelled_dataset` are mislabelled, we can compare the labels between `mislabelled_dataset` and `correct_dataset`. Also note that since both datasets have the same features, `mislabelled_dataset` is defined in terms of `correct_dataset`."
|
1567 | 1567 | ]
|
|