Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/attr/layer/test_grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch
from captum._utils.typing import TensorLikeList
from captum.attr._core.layer.grad_cam import LayerGradCam
from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorTuplesAlmostEqual
from tests.helpers.basic_models import (
BasicModel_ConvNet_One_Conv,
BasicModel_MultiLayer,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/layer/test_layer_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from captum.attr import LayerLRP
from captum.attr._utils.lrp_rules import Alpha1_Beta0_Rule, EpsilonRule, GammaRule

from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel
from torch import Tensor

Expand Down
3 changes: 2 additions & 1 deletion tests/attr/neuron/test_neuron_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
TensorOrTupleOfTensorsGeneric,
)
from captum.attr._core.neuron.neuron_feature_ablation import NeuronFeatureAblation
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import (
BasicModel_ConvNet_One_Conv,
BasicModel_MultiLayer,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/neuron/test_neuron_conductance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.layer.layer_conductance import LayerConductance
from captum.attr._core.neuron.neuron_conductance import NeuronConductance
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import (
BasicModel_ConvNet,
BasicModel_MultiLayer,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/neuron/test_neuron_deeplift.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
_create_inps_and_base_for_deeplift_neuron_layer_testing,
_create_inps_and_base_for_deepliftshap_neuron_layer_testing,
)
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import (
BasicModel_ConvNet,
BasicModel_ConvNet_MaxPool3d,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/neuron/test_neuron_gradient_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from captum.attr._core.neuron.neuron_integrated_gradients import (
NeuronIntegratedGradients,
)
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicModel_MultiLayer
from tests.helpers.classification_models import SoftmaxModel
from torch import Tensor
Expand Down
2 changes: 1 addition & 1 deletion tests/attr/test_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from captum.attr._utils.baselines import ProductBaselines

# from parameterized import parameterized
from tests.helpers.basic import BaseTest
from tests.helpers import BaseTest


class TestProductBaselines(BaseTest):
Expand Down
2 changes: 1 addition & 1 deletion tests/attr/test_class_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from captum.attr import ClassSummarizer, CommonStats
from tests.helpers.basic import BaseTest
from tests.helpers import BaseTest


class Test(BaseTest):
Expand Down
2 changes: 1 addition & 1 deletion tests/attr/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from captum.attr._core.noise_tunnel import SUPPORTED_NOISE_TUNNEL_TYPES
from captum.attr._utils.common import _validate_input, _validate_noise_tunnel_type
from tests.helpers.basic import BaseTest
from tests.helpers import BaseTest


class Test(BaseTest):
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import (
NeuronDeconvolution,
)
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv
from torch.nn import Module

Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.noise_tunnel import NoiseTunnel
from captum.attr._utils.attribution import Attribution
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import (
BasicModel,
BasicModel_ConvNet_One_Conv,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_feature_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import torch
from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicModelWithSparseInputs
from torch import Tensor

Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_gradient_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from captum.attr._core.gradient_shap import GradientShap
from captum.attr._core.integrated_gradients import IntegratedGradients
from numpy import ndarray
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicLinearModel, BasicModel2
from tests.helpers.classification_models import SoftmaxModel

Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_guided_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import (
NeuronGuidedBackprop,
)
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv
from torch.nn import Module

Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_guided_grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.attr._core.guided_grad_cam import GuidedGradCam
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv
from torch import Tensor
from torch.nn import Module
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_hook_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
should_create_generated_test,
)
from tests.attr.helpers.test_config import config
from tests.helpers.basic import BaseTest, deep_copy_args
from tests.helpers import BaseTest
from tests.helpers.basic import deep_copy_args
from torch.nn import Module

"""
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_input_x_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from captum.attr._core.input_x_gradient import InputXGradient
from captum.attr._core.noise_tunnel import NoiseTunnel
from tests.attr.test_saliency import _get_basic_config, _get_multiargs_basic_config
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.classification_models import SoftmaxModel
from torch import Tensor
from torch.nn import Module
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_integrated_gradients_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._core.noise_tunnel import NoiseTunnel
from captum.attr._utils.common import _tensorize_baseline
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import (
BasicModel,
BasicModel2,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_integrated_gradients_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from captum._utils.typing import BaselineType, Tensor
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._core.noise_tunnel import NoiseTunnel
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.classification_models import SigmoidModel, SoftmaxModel
from torch.nn import Module

Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_interpretable_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch import Tensor


Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_llm_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from captum.attr._core.shapley_value import ShapleyValueSampling
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
from parameterized import parameterized, parameterized_class
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch import nn, Tensor


Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
GammaRule,
IdentityRule,
)
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import (
BasicModel_ConvNet_One_Conv,
BasicModel_MultiLayer,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
TensorOrTupleOfTensorsGeneric,
)
from captum.attr._core.occlusion import Occlusion
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import (
BasicModel3,
BasicModel_ConvNet_One_Conv,
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import torch
from captum.attr import Max, Mean, Min, MSE, StdDev, Sum, Summarizer, Var
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual


def get_values(n: int = 100, lo=None, hi=None, integers: bool = False):
Expand Down
2 changes: 1 addition & 1 deletion tests/attr/test_summarizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
import torch
from captum.attr import CommonStats, Summarizer
from tests.helpers.basic import BaseTest
from tests.helpers import BaseTest


class Test(BaseTest):
Expand Down
3 changes: 2 additions & 1 deletion tests/attr/test_utils_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
_batched_operator,
_tuple_splice_range,
)
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual


class Test(BaseTest):
Expand Down
2 changes: 1 addition & 1 deletion tests/concept/test_concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from captum.concept._core.concept import Concept
from captum.concept._utils.data_iterator import dataset_to_dataloader
from tests.helpers.basic import BaseTest
from tests.helpers import BaseTest
from torch.utils.data import IterableDataset


Expand Down
3 changes: 2 additions & 1 deletion tests/concept/test_tcav.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from captum.concept._utils.classifier import Classifier
from captum.concept._utils.common import concepts_to_str
from captum.concept._utils.data_iterator import dataset_to_dataloader
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.basic_models import BasicModel_ConvNet
from torch import Tensor
from torch.utils.data import DataLoader, IterableDataset
Expand Down
11 changes: 11 additions & 0 deletions tests/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3

try:
from tests.helpers.fb.internal_base import FbBaseTest as BaseTest

__all__ = [
"BaseTest",
]

except ImportError:
from tests.helpers.basic import BaseTest
1 change: 1 addition & 0 deletions tests/helpers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import random
import unittest

from typing import Callable

import numpy as np
Expand Down
3 changes: 2 additions & 1 deletion tests/influence/_core/test_arnoldi_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
_unflatten_params_factory,
)
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
Expand Down
3 changes: 2 additions & 1 deletion tests/influence/_core/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
TracInCPFastRandProj,
)
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
Expand Down
7 changes: 2 additions & 5 deletions tests/influence/_core/test_naive_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@
_unflatten_params_factory,
)
from parameterized import parameterized
from tests.helpers.basic import (
assertTensorAlmostEqual,
assertTensorTuplesAlmostEqual,
BaseTest,
)
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
Expand Down
3 changes: 2 additions & 1 deletion tests/influence/_core/test_similarity_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
euclidean_distance,
SimilarityInfluence,
)
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch.utils.data import Dataset


Expand Down
3 changes: 2 additions & 1 deletion tests/influence/_core/test_tracin_intermediate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
TracInCPFastRandProj,
)
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
Expand Down
3 changes: 2 additions & 1 deletion tests/influence/_core/test_tracin_k_most_influential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from captum.influence._core.tracincp import TracInCP

from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
Expand Down
3 changes: 2 additions & 1 deletion tests/influence/_core/test_tracin_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
TracInCPFastRandProj,
)
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.influence._utils.common import (
_isSorted,
_wrap_model_in_dataparallel,
Expand Down
3 changes: 2 additions & 1 deletion tests/influence/_core/test_tracin_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from captum.influence._core.tracincp import TracInCP, TracInCPBase
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
Expand Down
2 changes: 1 addition & 1 deletion tests/influence/_core/test_tracin_show_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from captum.influence._core.tracincp import TracInCP
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast
from parameterized import parameterized
from tests.helpers.basic import BaseTest
from tests.helpers import BaseTest
from tests.influence._utils.common import (
build_test_name_func,
DataInfluenceConstructor,
Expand Down
2 changes: 1 addition & 1 deletion tests/influence/_core/test_tracin_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast

from parameterized import parameterized
from tests.helpers.basic import BaseTest
from tests.helpers import BaseTest
from tests.influence._utils.common import (
build_test_name_func,
DataInfluenceConstructor,
Expand Down
Loading