Merged
Changes from 15 commits
44 changes: 30 additions & 14 deletions captum/optim/_param/image/images.py
@@ -13,7 +13,7 @@
print("The Pillow/PIL library is required to use Captum's Optim library")

from captum.optim._param.image.transform import SymmetricPadding, ToRGB
from captum.optim._utils.typing import InitSize, SquashFunc
from captum.optim._utils.typing import SquashFuncType


class ImageTensor(torch.Tensor):
@@ -183,7 +183,7 @@ class FFTImage(ImageParameterization):

def __init__(
self,
size: InitSize = None,
size: Tuple[int, int] = None,
channels: int = 3,
batch: int = 1,
init: Optional[torch.Tensor] = None,
@@ -271,7 +271,7 @@ def forward(self) -> torch.Tensor:
class PixelImage(ImageParameterization):
def __init__(
self,
size: InitSize = None,
size: Tuple[int, int] = None,
channels: int = 3,
batch: int = 1,
init: Optional[torch.Tensor] = None,
@@ -292,7 +292,7 @@ def forward(self) -> torch.Tensor:
class LaplacianImage(ImageParameterization):
def __init__(
self,
size: InitSize = None,
size: Tuple[int, int] = None,
channels: int = 3,
batch: int = 1,
init: Optional[torch.Tensor] = None,
@@ -318,7 +318,7 @@ def __init__(

def setup_input(
self,
size: InitSize,
size: Tuple[int, int],
channels: int,
power: float = 0.1,
init: Optional[torch.Tensor] = None,
@@ -470,49 +470,65 @@ class NaturalImage(ImageParameterization):
r"""Outputs an optimizable input image.

By convention, single images are CHW and float32s in [0,1].
The underlying parameterization is decorrelated via a ToRGB transform.
The underlying parameterization can be decorrelated via a ToRGB transform.
When used with the (default) FFT parameterization, this results in a fully
uncorrelated image parameterization. :-)

If a model requires a normalization step, such as normalizing imagenet RGB values,
or rescaling to [0,255], it can perform those steps with the provided transforms or
inside its computation.
For example, our GoogleNet factory function has a `transform_input=True` argument.

Arguments:
size (Tuple[int, int]): The height and width to use for the nn.Parameter image
tensor.
channels (int): The number of channels to use when creating the
nn.Parameter tensor.
batch (int): The batch size to use when creating the
nn.Parameter tensor, or for stacking init images.
parameterization (ImageParameterization): An image parameterization class.
squash_func (SquashFuncType): The squash function to use after
color recorrelation. A function or lambda function.
decorrelation_module (nn.Module): A ToRGB instance.
decorrelate_init (bool): Whether or not to apply color decorrelation to the
init tensor input.
"""

def __init__(
self,
size: InitSize = None,
size: Tuple[int, int] = None,
channels: int = 3,
batch: int = 1,
parameterization: ImageParameterization = FFTImage,
init: Optional[torch.Tensor] = None,
parameterization: ImageParameterization = FFTImage,
squash_func: Optional[SquashFuncType] = None,
decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"),
decorrelate_init: bool = True,
squash_func: Optional[SquashFunc] = None,
) -> None:
super().__init__()
self.decorrelate = ToRGB(transform_name="klt")
self.decorrelate = decorrelation_module
if init is not None:
assert init.dim() == 3 or init.dim() == 4
if decorrelate_init:
assert self.decorrelate is not None
init = (
init.refine_names("B", "C", "H", "W")
if init.dim() == 4
else init.refine_names("C", "H", "W")
)
init = self.decorrelate(init, inverse=True).rename(None)
if squash_func is None:
squash_func: SquashFunc = lambda x: x.clamp(0, 1)
squash_func: SquashFuncType = lambda x: x.clamp(0, 1)
else:
if squash_func is None:
squash_func: SquashFunc = lambda x: torch.sigmoid(x)
squash_func: SquashFuncType = lambda x: torch.sigmoid(x)
self.squash_func = squash_func
self.parameterization = parameterization(
size=size, channels=channels, batch=batch, init=init
)

def forward(self) -> torch.Tensor:
image = self.parameterization()
image = self.decorrelate(image)
if self.decorrelate is not None:
image = self.decorrelate(image)
image = image.rename(None) # TODO: the world is not yet ready
return CudaImageTensor(self.squash_func(image))
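
For reviewers, here is a minimal usage sketch of the updated NaturalImage signature (the argument names mirror the diff above; the sizes and the custom arguments are purely illustrative):

import torch
from captum.optim._param.image.images import NaturalImage
from captum.optim._param.image.transform import ToRGB

# Defaults: FFT parameterization, KLT decorrelation, and (with no init tensor)
# a sigmoid squash function applied after recorrelation.
image = NaturalImage(size=(224, 224), channels=3, batch=1)

# Keyword arguments exposed by this PR: a custom squash function and an
# alternative decorrelation module (I1I2I3 instead of the default KLT).
image_custom = NaturalImage(
    size=(128, 128),
    squash_func=lambda x: x.clamp(0, 1),
    decorrelation_module=ToRGB(transform="i1i2i3"),
)

out = image_custom()  # NCHW float tensor, squashed into [0, 1]
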
125 changes: 76 additions & 49 deletions captum/optim/_param/image/transform.py
@@ -1,14 +1,18 @@
import math
import numbers
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from captum.optim._utils.image.common import nchannels_to_rgb
from captum.optim._utils.typing import TransformSize, TransformVal, TransformValList
from captum.optim._utils.typing import (
IntSeqOrIntType,
NumOrTensorType,
NumSeqOrTensorType,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -46,14 +50,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class ToRGB(nn.Module):
"""Transforms arbitrary channels to RGB. We use this to ensure our
image parameteriaztion itself can be decorrelated. So this goes between
the image parameterization and the normalization/sigmoid step.
We offer two transforms: Karhunen-Loève (KLT) and I1I2I3.
image parametrization itself can be decorrelated. So this goes between
the image parametrization and the normalization/sigmoid step.
We offer two precalculated transforms: Karhunen-Loève (KLT) and I1I2I3.
KLT corresponds to the empirically measured channel correlations on imagenet.
I1I2I3 corresponds to an aproximation for natural images from Ohta et al.[0]
I1I2I3 corresponds to an approximation for natural images from Ohta et al.[0]
[0] Y. Ohta, T. Kanade, and T. Sakai, "Color information for region segmentation,"
Computer Graphics and Image Processing, vol. 13, no. 3, pp. 222–241, 1980
https://www.sciencedirect.com/science/article/pii/0146664X80900477

Arguments:
transform (str or tensor): Either a string for one of the precalculated
transform matrices, or a 3x3 matrix for the 3 RGB channels of input
tensors.
"""

@staticmethod
@@ -73,15 +82,21 @@ def i1i2i3_transform() -> torch.Tensor:
]
return torch.Tensor(i1i2i3_matrix)

def __init__(self, transform_name: str = "klt") -> None:
def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None:
super().__init__()

if transform_name == "klt":
assert isinstance(transform, str) or torch.is_tensor(transform)
if torch.is_tensor(transform):
transform = cast(torch.Tensor, transform)
assert list(transform.shape) == [3, 3]
self.register_buffer("transform", transform)
elif transform == "klt":
self.register_buffer("transform", ToRGB.klt_transform())
elif transform_name == "i1i2i3":
elif transform == "i1i2i3":
self.register_buffer("transform", ToRGB.i1i2i3_transform())
else:
raise ValueError("transform_name has to be either 'klt' or 'i1i2i3'")
raise ValueError(
"transform has to be either 'klt', 'i1i2i3'," + " or a matrix tensor."
)

def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
assert x.dim() == 3 or x.dim() == 4
@@ -118,60 +133,72 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:

class CenterCrop(torch.nn.Module):
"""
Center crop the specified amount of pixels from the edges.
Center crop a specified amount from a tensor.
Arguments:
size (int, sequence) or (int): Number of pixels to center crop away.
size (int, sequence, int): Number of pixels to center crop away.
pixels_from_edges (bool, optional): Whether to treat crop size
values as the number of pixels from the tensor's edge, or an
exact shape in the center.
"""

def __init__(self, size: TransformSize = 0) -> None:
def __init__(
self, size: IntSeqOrIntType = 0, pixels_from_edges: bool = False
) -> None:
super(CenterCrop, self).__init__()
if type(size) is list or type(size) is tuple:
assert len(size) == 2, (
"CenterCrop requires a single crop value or a tuple of (height,width)"
+ "in pixels for cropping."
)
self.crop_val = size
else:
self.crop_val = [size] * 2
self.crop_vals = size
self.pixels_from_edges = pixels_from_edges

def forward(self, input: torch.Tensor) -> torch.Tensor:
assert (
input.dim() == 3 or input.dim() == 4
), "Input to CenterCrop must be 3D or 4D"
if input.dim() == 4:
h, w = input.size(2), input.size(3)
elif input.dim() == 3:
h, w = input.size(1), input.size(2)
h_crop = h - self.crop_val[0]
w_crop = w - self.crop_val[1]
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
return input[..., sh : sh + h_crop, sw : sw + w_crop]
"""
Center crop an input.
Arguments:
input (torch.Tensor): Input to center crop.
Returns:
tensor (torch.Tensor): A center cropped tensor.
"""

return center_crop(input, self.crop_vals, self.pixels_from_edges)

def center_crop_shape(input: torch.Tensor, output_size: List[int]) -> torch.Tensor:

def center_crop(
input: torch.Tensor, crop_vals: IntSeqOrIntType, pixels_from_edges: bool = False
) -> torch.Tensor:
"""
Crop NCHW & CHW outputs by specifying the desired output shape.
Center crop a specified amount from a tensor.
Arguments:
input (tensor): A CHW or NCHW image tensor to center crop.
crop_vals (int, sequence, int): Number of pixels to center crop away.
pixels_from_edges (bool, optional): Whether to treat crop size
values as the number of pixels from the tensor's edge, or an
exact shape in the center.
Returns:
*tensor*: A center cropped tensor.
"""

assert input.dim() == 4 or input.dim() == 3
output_size = [output_size] if not hasattr(output_size, "__iter__") else output_size
assert len(output_size) == 1 or len(output_size) == 2
output_size = output_size * 2 if len(output_size) == 1 else output_size
assert input.dim() == 3 or input.dim() == 4
crop_vals = [crop_vals] if not hasattr(crop_vals, "__iter__") else crop_vals
crop_vals = cast(Union[List[int], Tuple[int], Tuple[int, int]], crop_vals)
assert len(crop_vals) == 1 or len(crop_vals) == 2
crop_vals = crop_vals * 2 if len(crop_vals) == 1 else crop_vals

if input.dim() == 4:
h, w = input.size(2), input.size(3)
if input.dim() == 3:
h, w = input.size(1), input.size(2)

h_crop = h - int(round((h - output_size[0]) / 2.0))
w_crop = w - int(round((w - output_size[1]) / 2.0))

return input[
..., h_crop - output_size[0] : h_crop, w_crop - output_size[1] : w_crop
]
if pixels_from_edges:
h_crop = h - crop_vals[0]
w_crop = w - crop_vals[1]
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
x = input[..., sh : sh + h_crop, sw : sw + w_crop]
else:
h_crop = h - int(round((h - crop_vals[0]) / 2.0))
w_crop = w - int(round((w - crop_vals[1]) / 2.0))
x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
return x


def rand_select(transform_values: TransformValList) -> TransformVal:
def rand_select(transform_values: NumSeqOrTensorType) -> NumOrTensorType:
"""
Randomly return a value from the provided tuple or list
"""
@@ -186,19 +213,19 @@ class RandomScale(nn.Module):
scale (float, sequence): Tuple of rescaling values to randomly select from.
"""

def __init__(self, scale: TransformValList) -> None:
def __init__(self, scale: NumSeqOrTensorType) -> None:
super(RandomScale, self).__init__()
self.scale = scale

def get_scale_mat(
self, m: TransformVal, device: torch.device, dtype: torch.dtype
self, m: IntSeqOrIntType, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
scale_mat = torch.tensor(
[[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype
)
return scale_mat

def scale_tensor(self, x: torch.Tensor, scale: TransformVal) -> torch.Tensor:
def scale_tensor(self, x: torch.Tensor, scale: NumOrTensorType) -> torch.Tensor:
scale_matrix = self.get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat(
x.shape[0], 1, 1
)
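
A short sketch of the two behavior changes in this file: ToRGB now accepts a raw 3x3 matrix in addition to the precalculated string transforms, and center_crop gains a pixels_from_edges flag. The identity matrix and the named-dimension call pattern (mirroring NaturalImage) are illustrative assumptions:

import torch
from captum.optim._param.image.transform import ToRGB, center_crop

# ToRGB with a custom 3x3 color transform matrix instead of "klt" / "i1i2i3".
to_rgb = ToRGB(transform=torch.eye(3))  # identity matrix, for illustration only
x = torch.rand(1, 3, 64, 64).refine_names("B", "C", "H", "W")
decorrelated = to_rgb(x, inverse=True).rename(None)  # same pattern NaturalImage uses

# center_crop: the default crops to an exact center shape...
img = torch.rand(1, 3, 64, 64)
cropped_exact = center_crop(img, (32, 32))                   # -> 1x3x32x32
# ...while pixels_from_edges=True removes the given number of pixels from the
# height and width, the previous CenterCrop behavior.
cropped_edges = center_crop(img, 8, pixels_from_edges=True)  # -> 1x3x56x56
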
19 changes: 11 additions & 8 deletions captum/optim/_utils/circuits.py
@@ -1,19 +1,20 @@
from typing import Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn

from captum.optim._param.image.transform import center_crop_shape
from captum.optim._param.image.transform import center_crop
from captum.optim._utils.models import collect_activations
from captum.optim._utils.typing import ModelInputType, TransformSize
from captum.optim._utils.typing import IntSeqOrIntType, ModelInputType


def get_expanded_weights(
model,
target1: nn.Module,
target2: nn.Module,
crop_shape: Optional[Union[Tuple[int, int], TransformSize]] = None,
crop_shape: Optional[Union[Tuple[int, int], IntSeqOrIntType]] = None,
model_input: ModelInputType = torch.zeros(1, 3, 224, 224),
crop_func: Optional[Callable] = center_crop,
) -> torch.Tensor:
"""
Extract meaningful weight interactions from between neurons which aren’t
@@ -32,6 +33,8 @@ def get_expanded_weights(
size to enter crop away padding.
model_input (tensor or tuple of tensors, optional): The input to use
with the specified model.
crop_func (Callable, optional): Specify a function to crop away the padding
from the output weights.
Returns:
*tensor*: A tensor containing the expanded weights in the form of:
(target2 output channels, target1 output channels, y, x)
@@ -56,8 +59,8 @@ def get_expanded_weights(
retain_graph=True,
)[0]
A.append(x.squeeze(0))
exapnded_weights = torch.stack(A, 0)
expanded_weights = torch.stack(A, 0)

if crop_shape is not None:
exapnded_weights = center_crop_shape(exapnded_weights, crop_shape)
return exapnded_weights
if crop_shape is not None and crop_func is not None:
expanded_weights = crop_func(expanded_weights, crop_shape)
return expanded_weights
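
A hedged sketch of the updated get_expanded_weights call. The torchvision GoogLeNet and the layer choices below are stand-in assumptions, not part of this diff; the crop_func argument defaults to center_crop as shown above:

import torch
from torchvision import models
from captum.optim._utils.circuits import get_expanded_weights
from captum.optim._param.image.transform import center_crop

model = models.googlenet().eval()  # stand-in model, randomly initialized

# Expanded weights between two adjacent mixed layers, with the padding cropped
# away to a 5x5 spatial extent via the new crop_func argument.
weights = get_expanded_weights(
    model,
    model.inception4c,
    model.inception4d,
    crop_shape=(5, 5),
    crop_func=center_crop,
)
print(weights.shape)  # (inception4d channels, inception4c channels, 5, 5)
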
9 changes: 4 additions & 5 deletions captum/optim/_utils/typing.py
@@ -33,9 +33,8 @@ def cleanup(self):
LossFunction = Callable[[ModuleOutputMapping], Tensor]
SingleTargetLossFunction = Callable[[Tensor], Tensor]

InitSize = Tuple[int, int]
SquashFunc = Callable[[Tensor], Tensor]
TransformValList = Union[Sequence[int], Sequence[float], Tensor]
TransformVal = Union[int, float, Tensor]
TransformSize = Union[List[int], Tuple[int], int]
SquashFuncType = Callable[[Tensor], Tensor]
Contributor:
Does [Tensor] mean a list of tensors?

@ProGamerGov (Contributor Author), Jan 2, 2021:
No, it's just how lambda functions are type hinted: Callable[[<in type>], <out type>].
NumSeqOrTensorType = Union[Sequence[int], Sequence[float], Tensor]
NumOrTensorType = Union[int, float, Tensor]
IntSeqOrIntType = Union[List[int], Tuple[int], Tuple[int, int], int]
ModelInputType = Union[Tuple[Tensor], Tensor]
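
To illustrate the exchange above: in Callable[[Tensor], Tensor] the inner brackets are the callable's argument list (a single positional Tensor), not a list of tensors, so both regular functions and lambdas satisfy SquashFuncType. A minimal sketch:

import torch
from torch import Tensor
from typing import Callable

SquashFuncType = Callable[[Tensor], Tensor]  # one Tensor in, one Tensor out

def clamp_squash(x: Tensor) -> Tensor:
    return x.clamp(0, 1)

clamp_fn: SquashFuncType = clamp_squash                   # plain function
sigmoid_fn: SquashFuncType = lambda x: torch.sigmoid(x)   # lambda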