diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index e9fba1ba27..50cb903fd0 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -1,256 +1,324 @@ -import math -from inspect import signature -from typing import Dict, List, Tuple, Type, Union, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from captum.optim._core.output_hook import ActivationFetcher -from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType - - -def get_model_layers(model: nn.Module) -> List[str]: - """ - Return a list of hookable layers for the target model. - """ - layers = [] - - def get_layers(net: nn.Module, prefix: List = []) -> None: - if hasattr(net, "_modules"): - for name, layer in net._modules.items(): - if layer is None: - continue - separator = "" if str(name).isdigit() else "." - name = "[" + str(name) + "]" if str(name).isdigit() else name - layers.append(separator.join(prefix + [name])) - get_layers(layer, prefix=prefix + [name]) - - get_layers(model) - return layers - - -class RedirectedReLU(torch.autograd.Function): - """ - A workaround when there is no gradient flow from an initial random input. - ReLU layers will block the gradient flow during backpropagation when their - input is less than 0. This means that it can be impossible to visualize a - target without allowing negative values to pass through ReLU layers during - backpropagation. - See: - https://github.com/tensorflow/lucid/blob/master/lucid/misc/redirected_relu_grad.py - """ - - @staticmethod - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - self.save_for_backward(input_tensor) - return input_tensor.clamp(min=0) - - @staticmethod - def backward(self, grad_output: torch.Tensor) -> torch.Tensor: - (input_tensor,) = self.saved_tensors - relu_grad = grad_output.clone() - relu_grad[input_tensor < 0] = 0 - if torch.equal(relu_grad, torch.zeros_like(relu_grad)): - # Let "wrong" gradients flow if gradient is completely 0 - return grad_output.clone() - return relu_grad - - -class RedirectedReluLayer(nn.Module): - """ - Class for applying RedirectedReLU - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return RedirectedReLU.apply(input) - - -def replace_layers( - model: nn.Module, - layer1: Type[nn.Module], - layer2: Type[nn.Module], - transfer_vars: bool = False, - **kwargs -) -> None: - """ - Replace all target layers with new layers inside the specified model, - possibly with the same initialization variables. - - Args: - model: (nn.Module): A PyTorch model instance. - layer1: (Type[nn.Module]): The layer class that you want to transfer - initialization variables from. - layer2: (Type[nn.Module]): The layer class to create with the variables - from layer1. - transfer_vars (bool, optional): Wether or not to try and copy - initialization variables from layer1 instances to the replacement - layer2 instances. - kwargs: (Any, optional): Any additional variables to use when creating - the new layer. 
- """ - - for name, child in model._modules.items(): - if isinstance(child, layer1): - if transfer_vars: - new_layer = _transfer_layer_vars(child, layer2, **kwargs) - else: - new_layer = layer2(**kwargs) - setattr(model, name, new_layer) - elif child is not None: - replace_layers(child, layer1, layer2, transfer_vars, **kwargs) - - -def _transfer_layer_vars( - layer1: nn.Module, layer2: Type[nn.Module], **kwargs -) -> nn.Module: - """ - Given a layer instance, create a new layer instance of another class - with the same initialization variables as the original layer. - Args: - layer1: (nn.Module): A layer instance that you want to transfer - initialization variables from. - layer2: (nn.Module): The layer class to create with the variables - from of layer1. - kwargs: (Any, optional): Any additional variables to use when creating - the new layer. - Returns: - layer2 instance (nn.Module): An instance of layer2 with the initialization - variables that it shares with layer1, and any specified additional - initialization variables. - """ - - l2_vars = list(signature(layer2.__init__).parameters.values()) - l2_vars = [ - str(l2_vars[i]).split()[0] - for i in range(len(l2_vars)) - if str(l2_vars[i]) != "self" - ] - l2_vars = [p.split(":")[0] if ":" in p and "=" not in p else p for p in l2_vars] - l2_vars = [p.split("=")[0] if "=" in p and ":" not in p else p for p in l2_vars] - layer2_vars: Dict = {k: [] for k in dict.fromkeys(l2_vars).keys()} - - layer1_vars = {k: v for k, v in vars(layer1).items() if not k.startswith("_")} - shared_vars = {k: v for k, v in layer1_vars.items() if k in layer2_vars} - new_vars = dict(item for d in (shared_vars, kwargs) for item in d.items()) - return layer2(**new_vars) - - -class Conv2dSame(nn.Conv2d): - """ - Tensorflow like 'SAME' convolution wrapper for 2D convolutions. - TODO: Replace with torch.nn.Conv2d when support for padding='same' - is in stable version - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]] = 1, - padding: Union[int, Tuple[int, int]] = 0, - dilation: Union[int, Tuple[int, int]] = 1, - groups: int = 1, - bias: bool = True, - ) -> None: - super().__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias - ) - - def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - ih, iw = x.size()[-2:] - kh, kw = self.weight.size()[-2:] - pad_h = self.calc_same_pad(i=ih, k=kh, s=self.stride[0], d=self.dilation[0]) - pad_w = self.calc_same_pad(i=iw, k=kw, s=self.stride[1], d=self.dilation[1]) - - if pad_h > 0 or pad_w > 0: - x = F.pad( - x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] - ) - return F.conv2d( - x, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - -def collect_activations( - model: nn.Module, - targets: Union[nn.Module, List[nn.Module]], - model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224), -) -> ModuleOutputMapping: - """ - Collect target activations for a model. - """ - if not hasattr(targets, "__iter__"): - targets = [targets] - catch_activ = ActivationFetcher(model, targets) - activ_out = catch_activ(model_input) - return activ_out - - -class SkipLayer(torch.nn.Module): - """ - This layer is made to take the place of any layer that needs to be skipped over - during the forward pass. 
Use cases include removing nonlinear activation layers - like ReLU for circuits research. - - This layer works almost exactly the same way that nn.Indentiy does, except it also - ignores any additional arguments passed to the forward function. Any layer replaced - by SkipLayer must have the same input and output shapes. - - See nn.Identity for more details: - https://pytorch.org/docs/stable/generated/torch.nn.Identity.html - - Args: - args (Any): Any argument. Arguments will be safely ignored. - kwargs (Any) Any keyword argument. Arguments will be safely ignored. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__() - - def forward( - self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - """ - Args: - x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors. - args (Any): Any argument. Arguments will be safely ignored. - kwargs (Any) Any keyword argument. Arguments will be safely ignored. - Returns: - x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or - tensors. - """ - return x - - -def skip_layers( - model: nn.Module, layers: Union[List[Type[nn.Module]], Type[nn.Module]] -) -> None: - """ - This function is a wrapper function for - replace_layers and replaces the target layer - with layers that do nothing. - This is useful for removing the nonlinear ReLU - layers when creating expanded weights. - Args: - model (nn.Module): A PyTorch model instance. - layers (nn.Module or list of nn.Module): The layer - class type to replace in the model. - """ - if not hasattr(layers, "__iter__"): - layers = cast(Type[nn.Module], layers) - replace_layers(model, layers, SkipLayer) - else: - layers = cast(List[Type[nn.Module]], layers) - for target_layer in layers: - replace_layers(model, target_layer, SkipLayer) +import math +from inspect import signature +from typing import Dict, List, Optional, Tuple, Type, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from captum.optim._core.output_hook import ActivationFetcher +from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType + + +def get_model_layers(model: nn.Module) -> List[str]: + """ + Return a list of hookable layers for the target model. + """ + layers = [] + + def get_layers(net: nn.Module, prefix: List = []) -> None: + if hasattr(net, "_modules"): + for name, layer in net._modules.items(): + if layer is None: + continue + separator = "" if str(name).isdigit() else "." + name = "[" + str(name) + "]" if str(name).isdigit() else name + layers.append(separator.join(prefix + [name])) + get_layers(layer, prefix=prefix + [name]) + + get_layers(model) + return layers + + +class RedirectedReLU(torch.autograd.Function): + """ + A workaround when there is no gradient flow from an initial random input. + ReLU layers will block the gradient flow during backpropagation when their + input is less than 0. This means that it can be impossible to visualize a + target without allowing negative values to pass through ReLU layers during + backpropagation. 
+    See:
+    https://github.com/tensorflow/lucid/blob/master/lucid/misc/redirected_relu_grad.py
+    """
+
+    @staticmethod
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        self.save_for_backward(input_tensor)
+        return input_tensor.clamp(min=0)
+
+    @staticmethod
+    def backward(self, grad_output: torch.Tensor) -> torch.Tensor:
+        (input_tensor,) = self.saved_tensors
+        relu_grad = grad_output.clone()
+        relu_grad[input_tensor < 0] = 0
+        if torch.equal(relu_grad, torch.zeros_like(relu_grad)):
+            # Let "wrong" gradients flow if gradient is completely 0
+            return grad_output.clone()
+        return relu_grad
+
+
+class RedirectedReluLayer(nn.Module):
+    """
+    Class for applying RedirectedReLU
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return RedirectedReLU.apply(input)
+
+
+def replace_layers(
+    model: nn.Module,
+    layer1: Type[nn.Module],
+    layer2: Type[nn.Module],
+    transfer_vars: bool = False,
+    **kwargs
+) -> None:
+    """
+    Replace all target layers with new layers inside the specified model,
+    possibly with the same initialization variables.
+
+    Args:
+        model: (nn.Module): A PyTorch model instance.
+        layer1: (Type[nn.Module]): The layer class that you want to transfer
+            initialization variables from.
+        layer2: (Type[nn.Module]): The layer class to create with the variables
+            from layer1.
+        transfer_vars (bool, optional): Whether or not to try and copy
+            initialization variables from layer1 instances to the replacement
+            layer2 instances.
+        kwargs: (Any, optional): Any additional variables to use when creating
+            the new layer.
+    """
+
+    for name, child in model._modules.items():
+        if isinstance(child, layer1):
+            if transfer_vars:
+                new_layer = _transfer_layer_vars(child, layer2, **kwargs)
+            else:
+                new_layer = layer2(**kwargs)
+            setattr(model, name, new_layer)
+        elif child is not None:
+            replace_layers(child, layer1, layer2, transfer_vars, **kwargs)
+
+
+def _transfer_layer_vars(
+    layer1: nn.Module, layer2: Type[nn.Module], **kwargs
+) -> nn.Module:
+    """
+    Given a layer instance, create a new layer instance of another class
+    with the same initialization variables as the original layer.
+
+    Args:
+        layer1: (nn.Module): A layer instance that you want to transfer
+            initialization variables from.
+        layer2: (nn.Module): The layer class to create with the variables
+            from layer1.
+        kwargs: (Any, optional): Any additional variables to use when creating
+            the new layer.
+
+    Returns:
+        layer2 instance (nn.Module): An instance of layer2 with the initialization
+            variables that it shares with layer1, and any specified additional
+            initialization variables.
+    """
+
+    l2_vars = list(signature(layer2.__init__).parameters.values())
+    l2_vars = [
+        str(l2_vars[i]).split()[0]
+        for i in range(len(l2_vars))
+        if str(l2_vars[i]) != "self"
+    ]
+    l2_vars = [p.split(":")[0] if ":" in p and "=" not in p else p for p in l2_vars]
+    l2_vars = [p.split("=")[0] if "=" in p and ":" not in p else p for p in l2_vars]
+    layer2_vars: Dict = {k: [] for k in dict.fromkeys(l2_vars).keys()}
+
+    layer1_vars = {k: v for k, v in vars(layer1).items() if not k.startswith("_")}
+    shared_vars = {k: v for k, v in layer1_vars.items() if k in layer2_vars}
+    new_vars = dict(item for d in (shared_vars, kwargs) for item in d.items())
+    return layer2(**new_vars)
+
+
+class Conv2dSame(nn.Conv2d):
+    """
+    TensorFlow-like 'SAME' convolution wrapper for 2D convolutions.
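+
+    Padding is computed at forward time from the input size, kernel size, stride,
+    and dilation, so that the output spatial size is ceil(input_size / stride),
+    matching TensorFlow's 'SAME' behavior.
+
+    A minimal usage sketch (the channel counts and input size below are
+    illustrative assumptions, not values required anywhere in Captum)::
+
+        >>> conv = Conv2dSame(in_channels=3, out_channels=8, kernel_size=3, stride=2)
+        >>> conv(torch.ones(1, 3, 224, 224)).shape[-2:]
+        torch.Size([112, 112])
+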
+    TODO: Replace with torch.nn.Conv2d when support for padding='same'
+    is in stable version
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Union[int, Tuple[int, int]] = 0,
+        dilation: Union[int, Tuple[int, int]] = 1,
+        groups: int = 1,
+        bias: bool = True,
+    ) -> None:
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
+        )
+
+    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
+        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        ih, iw = x.size()[-2:]
+        kh, kw = self.weight.size()[-2:]
+        pad_h = self.calc_same_pad(i=ih, k=kh, s=self.stride[0], d=self.dilation[0])
+        pad_w = self.calc_same_pad(i=iw, k=kw, s=self.stride[1], d=self.dilation[1])
+
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(
+                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+            )
+        return F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
+
+def collect_activations(
+    model: nn.Module,
+    targets: Union[nn.Module, List[nn.Module]],
+    model_input: TupleOfTensorsOrTensorType = torch.zeros(1, 3, 224, 224),
+) -> ModuleOutputMapping:
+    """
+    Collect target activations for a model.
+    """
+    if not hasattr(targets, "__iter__"):
+        targets = [targets]
+    catch_activ = ActivationFetcher(model, targets)
+    activ_out = catch_activ(model_input)
+    return activ_out
+
+
+class SkipLayer(torch.nn.Module):
+    """
+    This layer is made to take the place of any layer that needs to be skipped over
+    during the forward pass. Use cases include removing nonlinear activation layers
+    like ReLU for circuits research.
+
+    This layer works almost exactly the same way that nn.Identity does, except it also
+    ignores any additional arguments passed to the forward function. Any layer replaced
+    by SkipLayer must have the same input and output shapes.
+
+    See nn.Identity for more details:
+    https://pytorch.org/docs/stable/generated/torch.nn.Identity.html
+
+    Args:
+        args (Any): Any argument. Arguments will be safely ignored.
+        kwargs (Any): Any keyword argument. Arguments will be safely ignored.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        """
+        Args:
+            x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors.
+            args (Any): Any argument. Arguments will be safely ignored.
+            kwargs (Any): Any keyword argument. Arguments will be safely ignored.
+
+        Returns:
+            x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or
+                tensors.
+        """
+        return x
+
+
+def skip_layers(
+    model: nn.Module, layers: Union[List[Type[nn.Module]], Type[nn.Module]]
+) -> None:
+    """
+    This function is a thin wrapper around replace_layers that replaces the target
+    layers with layers that do nothing. This is useful for removing the nonlinear
+    ReLU layers when creating expanded weights.
+
+    Args:
+        model (nn.Module): A PyTorch model instance.
+        layers (nn.Module or list of nn.Module): The layer
+            class type to replace in the model.
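+
+    A minimal usage sketch (mirrors the unit tests; googlenet here refers to
+    captum.optim.models.googlenet, and any nn.Module works the same way)::
+
+        >>> model = googlenet(pretrained=True)
+        >>> # Every ReLU in the model now passes its input through unchanged.
+        >>> skip_layers(model, torch.nn.ReLU)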
+ """ + if not hasattr(layers, "__iter__"): + layers = cast(Type[nn.Module], layers) + replace_layers(model, layers, SkipLayer) + else: + layers = cast(List[Type[nn.Module]], layers) + for target_layer in layers: + replace_layers(model, target_layer, SkipLayer) + + +class MaxPool2dRelaxed(torch.nn.Module): + """ + A relaxed pooling layer, that's useful for calculating attributions of spatial + positions. Noise in the gradient is reduced by the continuous relaxation of the + gradient of models using this layer. + + This layer is meant to be combined with forward-mode AD, so that the class + attributions of spatial posititions can be estimated using the rate at which + increasing the neuron affects the output classes. + + This layer peforms a MaxPool2d operation on the input, while using an equivalent + AvgPool2d layer to compute the gradient. This means that the forward pass returns + nn.MaxPool2d(input) while the backward pass uses nn.AvgPool2d(input). + + Carter, et al., "Activation Atlas", Distill, 2019. + https://distill.pub/2019/activation-atlas/ + + The Lucid equivalent of this class can be found here: + https://github.com/ + tensorflow/lucid/blob/master/lucid/optvis/overrides/smoothed_maxpool_grad.py + + An additional Lucid reference implementation can be found here: + https://colab.research.google.com/github/tensorflow/ + lucid/blob/master/notebooks/building-blocks/AttrSpatial.ipynb + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, ...]], + stride: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Union[int, Tuple[int, ...]] = 0, + ceil_mode: bool = False, + ) -> None: + """ + Args: + + kernel_size (int or tuple of int): The size of the window to perform max & + average pooling with. + stride (int or tuple of int, optional): The stride window size to use. + Default: None + padding (int or tuple of int): The amount of zero padding to add to both + sides in the nn.MaxPool2d & nn.AvgPool2d modules. + Default: 0 + ceil_mode (bool, optional): Whether to use ceil or floor for creating the + output shape. + Default: False + """ + super().__init__() + self.maxpool = torch.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + self.avgpool = torch.nn.AvgPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to run the pooling operations on. + + Returns: + x (torch.Tensor): A max pooled x tensor with gradient of an equivalent avg + pooled tensor. 
+ """ + return self.maxpool(x.detach()) + self.avgpool(x) - self.avgpool(x.detach()) diff --git a/tests/optim/models/test_models_common.py b/tests/optim/models/test_models_common.py index f6418b8d6c..176b10fff2 100644 --- a/tests/optim/models/test_models_common.py +++ b/tests/optim/models/test_models_common.py @@ -290,3 +290,48 @@ def test_skip_layers(self) -> None: model_utils.skip_layers(model, torch.nn.ReLU) output_tensor = model(x) assertTensorAlmostEqual(self, x, output_tensor, 0) + + +class TestMaxPool2dRelaxed(BaseTest): + def test_maxpool2d_relaxed_forward_data(self) -> None: + maxpool_relaxed = model_utils.MaxPool2dRelaxed( + kernel_size=3, stride=2, padding=0, ceil_mode=True + ) + maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) + + test_input = torch.arange(0, 1 * 3 * 8 * 8).view(1, 3, 8, 8).float() + + test_output_relaxed = maxpool_relaxed(test_input.clone()) + test_output_max = maxpool(test_input.clone()) + + assertTensorAlmostEqual(self, test_output_relaxed, test_output_max) + + def test_maxpool2d_relaxed_gradient(self) -> None: + maxpool_relaxed = model_utils.MaxPool2dRelaxed( + kernel_size=3, stride=2, padding=0, ceil_mode=True + ) + test_input = torch.nn.Parameter( + torch.arange(0, 1 * 1 * 4 * 4).view(1, 1, 4, 4).float() + ) + + test_output = maxpool_relaxed(test_input) + + output_grad = torch.autograd.grad( + outputs=[test_output], + inputs=[test_input], + grad_outputs=[test_output], + )[0] + + expected_output = torch.tensor( + [ + [ + [ + [1.1111, 1.1111, 2.9444, 1.8333], + [1.1111, 1.1111, 2.9444, 1.8333], + [3.4444, 3.4444, 9.0278, 5.5833], + [2.3333, 2.3333, 6.0833, 3.7500], + ] + ] + ], + ) + assertTensorAlmostEqual(self, output_grad, expected_output, 0.0005) diff --git a/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb b/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb new file mode 100644 index 0000000000..4e30a2c479 --- /dev/null +++ b/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb @@ -0,0 +1,621 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "ActivationAtlasSampleCollection_OptimViz.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "KP2PKna21WLK" + }, + "source": [ + "# Collecting Samples for Activation Atlases with captum.optim\n", + "\n", + "This notebook demonstrates how to collect the activation and corresponding attribution samples required for [Activation Atlases](https://distill.pub/2019/activation-atlas/) for the InceptionV1 model imported from Caffe." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "v6T6jxWb4cil" + }, + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from typing import List, Optional, Tuple, cast\n", + "\n", + "import os\n", + "import torch\n", + "import torchvision\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "from captum.optim.models import googlenet\n", + "\n", + "import captum.optim as opt\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dtE-t6ZG0-sJ" + }, + "source": [ + "### Dataset Download & Setup \n", + "\n", + "To begin, we'll need to download and setup the image dataset that our model was trained on. You can download ImageNet's ILSVRC2012 dataset from the [ImageNet website](http://www.image-net.org/challenges/LSVRC/2012/) or via BitTorrent from [Academic Torrents](https://academictorrents.com/details/a306397ccf9c2ead27155983c254227c0fd938e2)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lDt-6WMp0qh3" + }, + "source": [ + "collect_attributions = True # Set to False for no attributions\n", + "\n", + "# Setup basic transforms\n", + "# The model has the normalization step in its internal transform_input\n", + "# function, so we don't need to normalize our inputs here.\n", + "transform_list = [\n", + " torchvision.transforms.Resize((224, 224)),\n", + " torchvision.transforms.ToTensor(),\n", + "]\n", + "transform_list = torchvision.transforms.Compose(transform_list)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i85yBIhL7owj" + }, + "source": [ + "To make it easier to load the ImageNet dataset, we can use [Torchvision](https://pytorch.org/vision/stable/datasets.html#imagenet)'s `torchvision.datasets.ImageNet` instead of the default `ImageFolder`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3oRqxlMq7gJ4" + }, + "source": [ + "# Load the dataset\n", + "image_dataset = torchvision.datasets.ImageNet(\n", + " root=\"path/to/dataset\", split=\"train\", transform=transform_list\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "573290Fr8KN7" + }, + "source": [ + "Now we wrap our dataset in a `torch.utils.data.DataLoader` instance, and set the desired batch size." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "DUCfwsvR7iGC" + }, + "source": [ + "# Set desired batch size & load dataset with torch.utils.DataLoader\n", + "image_loader = torch.utils.data.DataLoader(\n", + " image_dataset,\n", + " batch_size=32,\n", + " shuffle=True,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4qfpBAPu18jv" + }, + "source": [ + "We load our model, then set the desired model target layers and corresponding file names." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qMViqsJ82Mcp" + }, + "source": [ + "# Model to collect samples from, what layers of the model to collect samples from,\n", + "# and the desired names to use for the target layers.\n", + "sample_model = (\n", + " googlenet(\n", + " pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True\n", + " )\n", + " .eval()\n", + " .to(device)\n", + ")\n", + "sample_targets = [sample_model.mixed4c_relu]\n", + "sample_target_names = [\"mixed4c_relu_samples\"]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jl719nyZEGSt" + }, + "source": [ + "By default the activation samples will not have the right class attributions, so we remedy this by loading a second instance of our model. We then replace all `nn.MaxPool2d` layers in the second model instance with Captum's `MaxPool2dRelaxed` layer. The relaxed max pooling layer lets us estimate the sample class attributions by determining the rate at which increasing the neuron affects the output classes." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "A-VJyHRm1tqC" + }, + "source": [ + "# Optionally collect attributions from a copy of the first model that's\n", + "# been setup with relaxed pooling layers.\n", + "if collect_attributions:\n", + " sample_model_attr = (\n", + " googlenet(\n", + " pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True\n", + " )\n", + " .eval()\n", + " .to(device)\n", + " )\n", + " opt.models.replace_layers(\n", + " sample_model_attr,\n", + " torch.nn.MaxPool2d,\n", + " opt.models.MaxPool2dRelaxed,\n", + " transfer_vars=True,\n", + " )\n", + " sample_attr_targets = [sample_model_attr.mixed4c_relu]\n", + " sample_logit_target = sample_model_attr.fc\n", + "else:\n", + " sample_model_attr = None\n", + " sample_attr_targets = None\n", + " sample_logit_target = None" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "32zDGSR5-qDW" + }, + "source": [ + "With our dataset loaded and models ready to go, we can now start collecting our samples. To perform the sample collection, we define a function called `capture_activation_samples` to randomly sample an x and y position for every image for all specified target layers." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2YLBCYP0J4Gq" + }, + "source": [ + "def attribute_spatial_position(\n", + " target_activ: torch.Tensor,\n", + " logit_activ: torch.Tensor,\n", + " position_mask: torch.Tensor,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " This function employs the double backward trick in order to perform\n", + " forward-mode AD.\n", + "\n", + " See here for more details:\n", + " https://github.com/renmengye/tensorflow-forward-ad/issues/2\n", + "\n", + " Based on the Collect Activations Lucid tutorial:\n", + " https://colab.research.google.com/github/tensorflow\n", + " /lucid/blob/master/notebooks/activation-atlas/activation-atlas-collect.ipynb\n", + "\n", + " Args:\n", + "\n", + " logit_activ: Captured activations from the FC / logit layer.\n", + " target_activ: Captured activations from the target layer.\n", + " position_mask (torch.Tensor, optional): If using a batch size greater than\n", + " one, a mask is used to zero out all the non-target positions.\n", + "\n", + " Returns:\n", + " logit_attr (torch.Tensor): A sorted list of class attributions for the target\n", + " spatial positions.\n", + " \"\"\"\n", + "\n", + " assert target_activ.dim() == 2 or target_activ.dim() == 4\n", + " assert logit_activ.dim() == 2\n", + "\n", + " zeros = torch.nn.Parameter(torch.zeros_like(logit_activ))\n", + " target_zeros = target_activ * position_mask\n", + "\n", + " grad_one = torch.autograd.grad(\n", + " outputs=[logit_activ],\n", + " inputs=[target_activ],\n", + " grad_outputs=[zeros],\n", + " create_graph=True,\n", + " )\n", + " logit_attr = torch.autograd.grad(\n", + " outputs=grad_one,\n", + " inputs=[zeros],\n", + " grad_outputs=[target_zeros],\n", + " create_graph=True,\n", + " )[0]\n", + " return logit_attr\n", + "\n", + "\n", + "def capture_activation_samples(\n", + " loader: torch.utils.data.DataLoader,\n", + " model: torch.nn.Module,\n", + " targets: List[torch.nn.Module],\n", + " target_names: Optional[List[str]] = None,\n", + " sample_dir: str = \"\",\n", + " num_images: Optional[int] = None,\n", + " samples_per_image: int = 1,\n", + " input_device: torch.device = torch.device(\"cpu\"),\n", + " collect_attributions: bool = False,\n", + " attr_model: Optional[torch.nn.Module] = None,\n", + " attr_targets: Optional[List[torch.nn.Module]] = None,\n", + " logit_target: Optional[torch.nn.Module] = None,\n", + " show_progress: bool = False,\n", + "):\n", + " \"\"\"\n", + " Capture randomly sampled activations & optional attributions for those samples,\n", + " for an image dataset from one or more target layers.\n", + "\n", + " Samples are saved to files for speed, memory efficient, and to preserve them in\n", + " the event of any crashes.\n", + "\n", + " Based on the Collect Activations Lucid tutorial:\n", + " https://colab.research.google.com/github/tensorflow\n", + " /lucid/blob/master/notebooks/activation-atlas/activation-atlas-collect.ipynb\n", + "\n", + " Args:\n", + "\n", + " loader (torch.utils.data.DataLoader): A torch.utils.data.DataLoader\n", + " instance for an image dataset.\n", + " model (nn.Module): A PyTorch model instance.\n", + " targets (list of nn.Module): A list of layers to collect activation samples\n", + " from.\n", + " target_names (list of str, optional): A list of names to use when saving sample\n", + " tensors as files. 
Names will automatically be chosen if set to None.\n", + " Default: None\n", + " sample_dir (str): Path to where activation samples should be saved.\n", + " Default: \"\"\n", + " num_images (int, optional): How many images to collect samples from.\n", + " Default is to collect samples for every image in the dataset. Set to None\n", + " to collect samples from every image in the dataset.\n", + " Default: None\n", + " samples_per_image (int): How many samples to collect per image.\n", + " Default: 1\n", + " input_device (torch.device, optional): The device to use for model\n", + " inputs.\n", + " Default: torch.device(\"cpu\")\n", + " collect_attributions (bool, optional): Whether or not to collect attributions\n", + " for samples.\n", + " Default: False\n", + " attr_model (nn.Module, optional): A PyTorch model instance to use for\n", + " calculating sample attributions.\n", + " Default: None\n", + " attr_targets (list of nn.Module, optional): A list of attribution model layers\n", + " to collect attributions from. This should be the exact same as the targets\n", + " parameter, except for the attribution model.\n", + " Default: None\n", + " logit_target (nn.Module, optional): The final layer in the attribution model\n", + " that determines the classes. This parameter is only enabled if\n", + " collect_attributions is set to True.\n", + " Default: None\n", + " show_progress (bool, optional): Whether or not to show progress.\n", + " Default: False\n", + " \"\"\"\n", + "\n", + " if target_names is None:\n", + " target_names = [\"target\" + str(i) + \"_\" for i in range(len(targets))]\n", + "\n", + " assert len(target_names) == len(targets)\n", + " assert os.path.isdir(sample_dir)\n", + "\n", + " def random_sample(\n", + " activations: torch.Tensor,\n", + " ) -> Tuple[List[torch.Tensor], List[List[List[int]]]]:\n", + " \"\"\"\n", + " Randomly sample H & W dimensions of activations with 4 dimensions.\n", + " \"\"\"\n", + " assert activations.dim() == 4 or activations.dim() == 2\n", + "\n", + " activation_samples: List = []\n", + " position_list: List = []\n", + "\n", + " with torch.no_grad():\n", + " for i in range(samples_per_image):\n", + " sample_position_list: List = []\n", + " for b in range(activations.size(0)):\n", + " if activations.dim() == 4:\n", + " h, w = activations.shape[2:]\n", + " y = torch.randint(low=1, high=h - 1, size=[1])\n", + " x = torch.randint(low=1, high=w - 1, size=[1])\n", + " activ = activations[b, :, y, x]\n", + " sample_position_list.append((b, y, x))\n", + " elif activations.dim() == 2:\n", + " activ = activations[b].unsqueeze(1)\n", + " sample_position_list.append(b)\n", + " activation_samples.append(activ)\n", + " position_list.append(sample_position_list)\n", + " return activation_samples, position_list\n", + "\n", + " def attribute_samples(\n", + " activations: torch.Tensor,\n", + " logit_activ: torch.Tensor,\n", + " position_list: List[List[List[int]]],\n", + " ) -> List[torch.Tensor]:\n", + " \"\"\"\n", + " Collect attributions for target sample positions.\n", + " \"\"\"\n", + " assert activations.dim() == 4 or activations.dim() == 2\n", + "\n", + " sample_attributions: List = []\n", + " with torch.set_grad_enabled(True):\n", + " zeros_mask = torch.zeros_like(activations)\n", + " for sample_pos_list in position_list:\n", + " for c in sample_pos_list:\n", + " if activations.dim() == 4:\n", + " zeros_mask[c[0], :, c[1], c[2]] = 1\n", + " elif activations.dim() == 2:\n", + " zeros_mask[c] = 1\n", + " attr = attribute_spatial_position(\n", + " activations, 
logit_activ, position_mask=zeros_mask\n", + " ).detach()\n", + " sample_attributions.append(attr)\n", + " return sample_attributions\n", + "\n", + " if collect_attributions:\n", + " logit_target == list(model.children())[len(list(model.children())) - 1 :][\n", + " 0\n", + " ] if logit_target is None else logit_target\n", + " attr_targets = cast(List[torch.nn.Module], attr_targets)\n", + " attr_targets += [cast(torch.nn.Module, logit_target)]\n", + "\n", + " if show_progress:\n", + " total = (\n", + " len(loader.dataset) if num_images is None else num_images # type: ignore\n", + " )\n", + " pbar = tqdm(total=total, unit=\" images\")\n", + "\n", + " image_count, batch_count = 0, 0\n", + " with torch.no_grad():\n", + " for inputs, _ in loader:\n", + " inputs = inputs.to(input_device)\n", + " image_count += inputs.size(0)\n", + " batch_count += 1\n", + "\n", + " target_activ_dict = opt.models.collect_activations(model, targets, inputs)\n", + " if collect_attributions:\n", + " with torch.set_grad_enabled(True):\n", + " target_activ_attr_dict = opt.models.collect_activations(\n", + " attr_model, attr_targets, inputs\n", + " )\n", + " logit_activ = target_activ_attr_dict[logit_target]\n", + " del target_activ_attr_dict[logit_target]\n", + "\n", + " sample_coords = []\n", + " for t, n in zip(target_activ_dict, target_names):\n", + " sample_tensors, p_list = random_sample(target_activ_dict[t])\n", + " torch.save(\n", + " sample_tensors,\n", + " os.path.join(\n", + " sample_dir, n + \"_activations_\" + str(batch_count) + \".pt\"\n", + " ),\n", + " )\n", + " sample_coords.append(p_list)\n", + "\n", + " if collect_attributions:\n", + " for t, n, s_coords in zip(\n", + " target_activ_attr_dict, target_names, sample_coords\n", + " ):\n", + " sample_attrs = attribute_samples(\n", + " target_activ_attr_dict[t], logit_activ, s_coords\n", + " )\n", + " torch.save(\n", + " sample_attrs,\n", + " os.path.join(\n", + " sample_dir,\n", + " n + \"_attributions_\" + str(batch_count) + \".pt\",\n", + " ),\n", + " )\n", + "\n", + " if show_progress:\n", + " pbar.update(inputs.size(0))\n", + "\n", + " if num_images is not None:\n", + " if image_count > num_images:\n", + " break\n", + "\n", + " if show_progress:\n", + " pbar.close()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IWsmPssJJ09E" + }, + "source": [ + "We now collect our activation samples and attribution, as we iterate through our image dataset. Note that this step can be rather time consuming depending on the image dataset being used." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uODdkyjY1lap" + }, + "source": [ + "# Directory to save sample files to\n", + "sample_dir = \"inceptionv1_samples\"\n", + "try:\n", + " os.mkdir(sample_dir)\n", + "except:\n", + " pass\n", + "\n", + "# Collect samples & optionally attributions as well\n", + "capture_activation_samples(\n", + " loader=image_loader,\n", + " model=sample_model,\n", + " targets=sample_targets,\n", + " target_names=sample_target_names,\n", + " attr_model=sample_model_attr,\n", + " attr_targets=sample_attr_targets,\n", + " input_device=device,\n", + " sample_dir=sample_dir,\n", + " show_progress=True,\n", + " collect_attributions=collect_attributions,\n", + " logit_target=sample_logit_target,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eMrBUaPi97fF" + }, + "source": [ + "Now that we've collected our samples, we need to combine them into a single tensor. Below we use the `consolidate_samples` function to load each list of tensor samples, and then concatinate them into a single tensor." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "LaFglPVYKbXj" + }, + "source": [ + "def consolidate_samples(\n", + " sample_dir: str,\n", + " sample_basename: str = \"\",\n", + " dim: int = 1,\n", + " num_files: Optional[int] = None,\n", + " show_progress: bool = False,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Combine samples collected from capture_activation_samples into a single tensor\n", + " with a shape of [n_target_classes, n_samples].\n", + "\n", + " Args:\n", + "\n", + " sample_dir (str): The directory where activation samples where saved.\n", + " sample_basename (str, optional): If samples from different layers are present\n", + " in sample_dir, then you can use samples from only a specific layer by\n", + " specifying the basename that samples of the same layer share.\n", + " Default: \"\"\n", + " dim (int, optional): The dimension to concatinate the samples together on.\n", + " Default: 1\n", + " num_files (int, optional): The number of sample files that you wish to\n", + " concatinate together, if you do not wish to concatinate all of them.\n", + " Default: None\n", + " show_progress (bool, optional): Whether or not to show progress.\n", + " Default: False\n", + "\n", + " Returns:\n", + " sample_tensor (torch.Tensor): A tensor containing all the specified sample\n", + " tensors with a shape of [n_target_classes, n_samples].\n", + " \"\"\"\n", + "\n", + " assert os.path.isdir(sample_dir)\n", + "\n", + " tensor_samples = [\n", + " os.path.join(sample_dir, name)\n", + " for name in os.listdir(sample_dir)\n", + " if sample_basename.lower() in name.lower()\n", + " and os.path.isfile(os.path.join(sample_dir, name))\n", + " ]\n", + " assert len(tensor_samples) > 0\n", + "\n", + " if show_progress:\n", + " total = len(tensor_samples) if num_files is None else num_files # type: ignore\n", + " pbar = tqdm(total=total, unit=\" sample batches collected\")\n", + "\n", + " samples: List[torch.Tensor] = []\n", + " for file in tensor_samples:\n", + " sample_batch = torch.load(file)\n", + " for s in sample_batch:\n", + " samples += [s.cpu()]\n", + " if show_progress:\n", + " pbar.update(1)\n", + "\n", + " if show_progress:\n", + " pbar.close()\n", + " return torch.cat(samples, dim)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "BKUPszVR1Ew-" + }, + "source": [ + "# Combine our newly collected samples into single tensors.\n", 
+ "# We load the sample tensors from sample_dir and then\n", + "# concatenate them.\n", + "\n", + "for name in sample_target_names:\n", + " print(\"Combining \" + name + \" samples:\")\n", + " activation_samples = consolidate_samples(\n", + " sample_dir=sample_dir,\n", + " sample_basename=name + \"_activations\",\n", + " dim=1,\n", + " show_progress=True,\n", + " )\n", + " if collect_attributions:\n", + " sample_attributions = consolidate_samples(\n", + " sample_dir=sample_dir,\n", + " sample_basename=name + \"_attributions\",\n", + " dim=0,\n", + " show_progress=True,\n", + " )\n", + "\n", + " # Save the results\n", + " torch.save(activation_samples, name + \"activation_samples.pt\")\n", + " if collect_attributions:\n", + " torch.save(sample_attributions, name + \"attribution_samples.pt\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dQig3atsa9UI" + }, + "source": [ + "Now that we have successfully collected the required sample activations & attributions, we can move onto the main Activation Atlas and Class Activation Atlas tutorials!" + ] + } + ] +} \ No newline at end of file