Bug in LimeBase example? #905

@th789

🐛 Bug

I am trying to run the LimeBase example from https://captum.ai/api/lime.html, but the call to `lime_attr.attribute(...)` fails with `TypeError: () got an unexpected keyword argument 'kernel_width'` (details below).

To Reproduce

```python
import torch
import torch.nn as nn
from captum.attr import LimeBase
from captum._utils.models.linear_model import SkLearnLinearModel

class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

# All code below is copied from the example provided in the docs
net = SimpleClassifier()

def similarity_kernel(original_input, perturbed_input, perturbed_interpretable_input, **kwargs):
    # kernel_width will be provided to attribute as a kwarg
    kernel_width = kwargs["kernel_width"]
    l2_dist = torch.norm(original_input - perturbed_input)
    return torch.exp(- (l2_dist**2) / (kernel_width**2))

def perturb_func(original_input, **kwargs):
    return original_input + torch.randn_like(original_input)

input = torch.randn(2, 5)

lime_attr = LimeBase(net,
                     SkLearnLinearModel("linear_model.Ridge"),
                     similarity_func=similarity_kernel,
                     perturb_func=perturb_func,
                     perturb_interpretable_space=False,
                     from_interp_rep_transform=None,
                     to_interp_rep_transform=lambda x: x)

attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)  # the error is raised by this line
```

Error message:
[Screenshot of the traceback: "Screen Shot 2022-03-19 at 1 18 04 PM"]
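A plausible cause, assuming LimeBase calls `to_interp_rep_transform(curr_sample, original_input, **kwargs)` when `perturb_interpretable_space=False` (the signature its docstring describes for that argument), is that every extra keyword passed to `attribute`, here `kernel_width`, is forwarded to the transform as well, while `lambda x: x` accepts only one positional argument. The pure-Python sketch below mimics that assumed call pattern. It is not Captum code: `call_to_interp_rep_transform` and `identity_transform` are hypothetical names, and plain lists stand in for tensors.

```python
from typing import Any, Callable, List


# Hypothetical stand-in (not Captum code) for the assumed way LimeBase.attribute
# invokes to_interp_rep_transform when perturb_interpretable_space=False:
# it passes the perturbed sample, the original input, and every extra kwarg
# that was given to attribute().
def call_to_interp_rep_transform(
    transform: Callable[..., Any],
    curr_sample: List[float],
    original_input: List[float],
    **kwargs: Any,
) -> Any:
    return transform(curr_sample, original_input, **kwargs)


# Transform as written in the docs example: exactly one positional parameter.
doc_transform = lambda x: x


# Hypothetical transform matching the assumed signature
# to_interp_rep_transform(curr_sample, original_input, **kwargs).
def identity_transform(
    curr_sample: List[float], original_input: List[float], **kwargs: Any
) -> List[float]:
    # The interpretable representation is simply the perturbed sample itself.
    return curr_sample


sample = [0.1, -0.4, 1.2, 0.0, 0.7]
original = [0.0, -0.5, 1.0, 0.2, 0.6]

try:
    call_to_interp_rep_transform(doc_transform, sample, original, kernel_width=1.1)
except TypeError as err:
    # The one-argument lambda accepts neither original_input nor kernel_width.
    print(f"one-argument lambda fails: TypeError: {err}")

# Accepts the extra positional argument and kernel_width, and ignores both.
print(call_to_interp_rep_transform(identity_transform, sample, original, kernel_width=1.1))
```

If this reading is correct, passing a transform with the `(curr_sample, original_input, **kwargs)` shape to `LimeBase` instead of `lambda x: x` should get `attribute(..., kernel_width=1.1)` past this particular error.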

Expected behavior

The call should complete without error, and `attr_coefs` should contain the feature attributions.

Environment

Describe the environment used for Captum

 - Captum version: 0.5.0
 - PyTorch version: 1.10.0+cu111
 - OS (e.g., Linux): macOS
 - How you installed Captum (`conda`, `pip`, source): both `conda` and `pip`; the error occurs whether I install Captum with `conda install captum -c pytorch` or with `pip install captum`
 - Python version: 3.7.12
