Skip to content

Commit b6e8670

Browse files
sarahtranfbfacebook-github-bot
authored andcommitted
Default enable_cross_tensor_attribution to true (#1643)
Summary: Internal usage: - PyPer/APS/MVAI FI at 100% since 4/29 - Fluent2 FI at 100% since 5/22 - NI at 100% since 7/3 Differential Revision: D81948483
1 parent c9fc87d commit b6e8670

File tree

4 files changed

+6
-7
lines changed

4 files changed

+6
-7
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def attribute(
114114
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
115115
perturbations_per_eval: int = 1,
116116
show_progress: bool = False,
117-
enable_cross_tensor_attribution: bool = False,
117+
enable_cross_tensor_attribution: bool = True,
118118
**kwargs: Any,
119119
) -> TensorOrTupleOfTensorsGeneric:
120120
r"""
@@ -744,7 +744,7 @@ def attribute_future(
744744
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
745745
perturbations_per_eval: int = 1,
746746
show_progress: bool = False,
747-
enable_cross_tensor_attribution: bool = False,
747+
enable_cross_tensor_attribution: bool = True,
748748
**kwargs: Any,
749749
) -> Future[TensorOrTupleOfTensorsGeneric]:
750750
r"""

captum/attr/_core/feature_permutation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def attribute( # type: ignore
115115
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
116116
perturbations_per_eval: int = 1,
117117
show_progress: bool = False,
118-
enable_cross_tensor_attribution: bool = False,
118+
enable_cross_tensor_attribution: bool = True,
119119
**kwargs: Any,
120120
) -> TensorOrTupleOfTensorsGeneric:
121121
r"""
@@ -304,7 +304,7 @@ def attribute_future(
304304
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
305305
perturbations_per_eval: int = 1,
306306
show_progress: bool = False,
307-
enable_cross_tensor_attribution: bool = False,
307+
enable_cross_tensor_attribution: bool = True,
308308
**kwargs: Any,
309309
) -> Future[TensorOrTupleOfTensorsGeneric]:
310310
"""

captum/attr/_core/occlusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def attribute( # type: ignore
267267
shift_counts=tuple(shift_counts),
268268
strides=strides,
269269
show_progress=show_progress,
270-
enable_cross_tensor_attribution=True,
271270
)
272271

273272
def attribute_future(self) -> None:

tests/attr/test_feature_permutation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,14 @@ def forward_func(x: Tensor) -> Tensor:
120120
inp = torch.tensor([[1.0, 2.0]])
121121
assertTensorAlmostEqual(
122122
self,
123-
feature_importance.attribute(inp),
123+
feature_importance.attribute(inp, enable_cross_tensor_attribution=False),
124124
torch.tensor([[0.0, 0.0]]),
125125
delta=0.0,
126126
)
127127

128128
feature_importance._min_examples_per_batch = 1
129129
with self.assertRaises(AssertionError):
130-
feature_importance.attribute(inp)
130+
feature_importance.attribute(inp, enable_cross_tensor_attribution=False)
131131

132132
def test_simple_input_with_min_examples_in_group(self) -> None:
133133
def forward_func(x: Tensor) -> Tensor:

0 commit comments

Comments
 (0)