Skip to content

Commit 993a6fd

Browse files
sarahtranfb authored and facebook-github-bot committed
Default enable_cross_tensor_attribution to true (#1643)
Summary: Internal usage: - PyPer/APS/MVAI FI at 100% since 4/29 - Fluent2 FI at 100% since 5/22 - NI at 100% since 7/3 Differential Revision: D81948483
1 parent f815abc commit 993a6fd

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def attribute(
114114
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
115115
perturbations_per_eval: int = 1,
116116
show_progress: bool = False,
117-
enable_cross_tensor_attribution: bool = False,
117+
enable_cross_tensor_attribution: bool = True,
118118
**kwargs: Any,
119119
) -> TensorOrTupleOfTensorsGeneric:
120120
r"""
@@ -704,9 +704,11 @@ def _construct_ablated_input_across_tensors(
704704
tensor_mask.append(mask)
705705

706706
assert baseline is not None, "baseline must be provided"
707-
ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
708-
1 - mask
709-
) + (baseline * mask.to(input_tensor.dtype))
707+
ablated_feature = input_tensor[start_idx:end_idx] * (1 - mask) + (
708+
baseline * mask.to(input_tensor.dtype)
709+
)
710+
ablated_input = ablated_input.to(ablated_feature.dtype)
711+
ablated_input[start_idx:end_idx] = ablated_feature
710712
current_masks.append(torch.stack(tensor_mask, dim=0))
711713
ablated_inputs.append(ablated_input)
712714

@@ -742,7 +744,7 @@ def attribute_future(
742744
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
743745
perturbations_per_eval: int = 1,
744746
show_progress: bool = False,
745-
enable_cross_tensor_attribution: bool = False,
747+
enable_cross_tensor_attribution: bool = True,
746748
**kwargs: Any,
747749
) -> Future[TensorOrTupleOfTensorsGeneric]:
748750
r"""

captum/attr/_core/feature_permutation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def attribute( # type: ignore
115115
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
116116
perturbations_per_eval: int = 1,
117117
show_progress: bool = False,
118-
enable_cross_tensor_attribution: bool = False,
118+
enable_cross_tensor_attribution: bool = True,
119119
**kwargs: Any,
120120
) -> TensorOrTupleOfTensorsGeneric:
121121
r"""
@@ -304,7 +304,7 @@ def attribute_future(
304304
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
305305
perturbations_per_eval: int = 1,
306306
show_progress: bool = False,
307-
enable_cross_tensor_attribution: bool = False,
307+
enable_cross_tensor_attribution: bool = True,
308308
**kwargs: Any,
309309
) -> Future[TensorOrTupleOfTensorsGeneric]:
310310
"""

captum/attr/_core/occlusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def attribute( # type: ignore
267267
shift_counts=tuple(shift_counts),
268268
strides=strides,
269269
show_progress=show_progress,
270-
enable_cross_tensor_attribution=True,
271270
)
272271

273272
def attribute_future(self) -> None:

tests/attr/test_feature_permutation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,14 @@ def forward_func(x: Tensor) -> Tensor:
120120
inp = torch.tensor([[1.0, 2.0]])
121121
assertTensorAlmostEqual(
122122
self,
123-
feature_importance.attribute(inp),
123+
feature_importance.attribute(inp, enable_cross_tensor_attribution=False),
124124
torch.tensor([[0.0, 0.0]]),
125125
delta=0.0,
126126
)
127127

128128
feature_importance._min_examples_per_batch = 1
129129
with self.assertRaises(AssertionError):
130-
feature_importance.attribute(inp)
130+
feature_importance.attribute(inp, enable_cross_tensor_attribution=False)
131131

132132
def test_simple_input_with_min_examples_in_group(self) -> None:
133133
def forward_func(x: Tensor) -> Tensor:

0 commit comments

Comments (0)