Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion captum/_utils/av.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import captum._utils.common as common
import torch
from captum.attr import LayerActivation
from captum.attr._core.layer.layer_activation import LayerActivation
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
Expand Down
20 changes: 0 additions & 20 deletions captum/_utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,5 @@
from captum._utils.models.linear_model import (
LinearModel,
SGDLasso,
SGDLinearModel,
SGDLinearRegression,
SGDRidge,
SkLearnLasso,
SkLearnLinearModel,
SkLearnLinearRegression,
SkLearnRidge,
)
from captum._utils.models.model import Model

__all__ = [
"Model",
"LinearModel",
"SGDLinearModel",
"SGDLasso",
"SGDRidge",
"SGDLinearRegression",
"SkLearnLinearModel",
"SkLearnLasso",
"SkLearnRidge",
"SkLearnLinearRegression",
]
52 changes: 52 additions & 0 deletions tests/utils/evaluate_linear_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from typing import cast, Dict

import torch
from torch import Tensor


def evaluate(test_data, classifier) -> Dict[str, Tensor]:
classifier.eval()

l1_loss = 0.0
l2_loss = 0.0
n = 0
l2_losses = []
with torch.no_grad():
for data in test_data:
if len(data) == 2:
x, y = data
w = None
else:
x, y, w = data

out = classifier(x)

y = y.view(x.shape[0], -1)
assert y.shape == out.shape

if w is None:
l1_loss += (out - y).abs().sum(0).to(dtype=torch.float64)
l2_loss += ((out - y) ** 2).sum(0).to(dtype=torch.float64)
l2_losses.append(((out - y) ** 2).to(dtype=torch.float64))
else:
l1_loss += (
(w.view(-1, 1) * (out - y)).abs().sum(0).to(dtype=torch.float64)
)
l2_loss += (
(w.view(-1, 1) * ((out - y) ** 2)).sum(0).to(dtype=torch.float64)
)
l2_losses.append(
(w.view(-1, 1) * ((out - y) ** 2)).to(dtype=torch.float64)
)

n += x.shape[0]

l2_losses = torch.cat(l2_losses, dim=0)
assert n > 0

# just to double check
assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all()

classifier.train()
return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)}
10 changes: 5 additions & 5 deletions tests/utils/models/linear_models/_test_linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import sklearn.datasets as datasets
import torch
from tests.utils.test_linear_model import _evaluate
from tests.utils.evaluate_linear_model import evaluate
from torch.utils.data import DataLoader, TensorDataset


Expand Down Expand Up @@ -80,11 +80,11 @@ def compare_to_sk_learn(
alpha=alpha,
)

sklearn_stats.update(_evaluate(val_loader, sklearn_classifier))
pytorch_stats.update(_evaluate(val_loader, pytorch_classifier))
sklearn_stats.update(evaluate(val_loader, sklearn_classifier))
pytorch_stats.update(evaluate(val_loader, pytorch_classifier))

train_stats_pytorch = _evaluate(train_loader, pytorch_classifier)
train_stats_sklearn = _evaluate(train_loader, sklearn_classifier)
train_stats_pytorch = evaluate(train_loader, pytorch_classifier)
train_stats_sklearn = evaluate(train_loader, sklearn_classifier)

o_pytorch = {"l2": train_stats_pytorch["l2"]}
o_sklearn = {"l2": train_stats_sklearn["l2"]}
Expand Down
52 changes: 3 additions & 49 deletions tests/utils/test_linear_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

from typing import cast, Dict, Optional, Union
from typing import Optional, Union

import torch
from captum._utils.models.linear_model.model import (
Expand All @@ -10,56 +10,10 @@
)
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.utils.evaluate_linear_model import evaluate
from torch import Tensor


def _evaluate(test_data, classifier) -> Dict[str, Tensor]:
classifier.eval()

l1_loss = 0.0
l2_loss = 0.0
n = 0
l2_losses = []
with torch.no_grad():
for data in test_data:
if len(data) == 2:
x, y = data
w = None
else:
x, y, w = data

out = classifier(x)

y = y.view(x.shape[0], -1)
assert y.shape == out.shape

if w is None:
l1_loss += (out - y).abs().sum(0).to(dtype=torch.float64)
l2_loss += ((out - y) ** 2).sum(0).to(dtype=torch.float64)
l2_losses.append(((out - y) ** 2).to(dtype=torch.float64))
else:
l1_loss += (
(w.view(-1, 1) * (out - y)).abs().sum(0).to(dtype=torch.float64)
)
l2_loss += (
(w.view(-1, 1) * ((out - y) ** 2)).sum(0).to(dtype=torch.float64)
)
l2_losses.append(
(w.view(-1, 1) * ((out - y) ** 2)).to(dtype=torch.float64)
)

n += x.shape[0]

l2_losses = torch.cat(l2_losses, dim=0)
assert n > 0

# just to double check
assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all()

classifier.train()
return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)}


class TestLinearModel(BaseTest):
MAX_POINTS: int = 3

Expand Down Expand Up @@ -100,7 +54,7 @@ def train_and_compare(

self.assertTrue(model.bias() is not None if bias else model.bias() is None)

l2_loss = _evaluate(train_loader, model)["l2"]
l2_loss = evaluate(train_loader, model)["l2"]

if objective == "lasso":
reg = model.representation().norm(p=1).view_as(l2_loss)
Expand Down