Skip to content

Commit c9fc87d

Browse files
sarahtranfb authored and facebook-github-bot committed
Update FeatureAblation to handle precision loss when baseline is more granular than input when cross tensor attribution is enabled
Summary: Noticed when flipping the flag, this test case failed: https://www.internalfb.com/code/fbsource/[faf71541b1ec0fae639f82d487b81fb18ea3e523]/fbcode/pytorch/captum/tests/attr/test_dataloader_attr.py?lines=138%2C134 The last ablated tensor was `tensor([[0], [0]])` instead of `tensor([[0.1], [0.1]])` since the baseline was a float-type (`0.1`) and the input tensors were int tensors. https://www.internalfb.com/code/fbsource/[f2fcc926a6f3669602bac4d28c2d92e4197c96b9]/fbcode/pytorch/captum/captum/attr/_core/feature_ablation.py?lines=707-709 `ablated_input` is just a copy of the `input_tensor`, so during assignment, the ablated feature tensor incorrectly gets cast to an int tensor for this case. PyPer/APS/MVAI FI don't use baselines. Fluent2 supports custom baselines, but only `ZeroFIBaseline` is currently defined. Reviewed By: styusuf Differential Revision: D81980219 fbshipit-source-id: e53f5d66643d0d7f5373de72460515c79fb5c869
1 parent f815abc commit c9fc87d

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,11 @@ def _construct_ablated_input_across_tensors(
704704
tensor_mask.append(mask)
705705

706706
assert baseline is not None, "baseline must be provided"
707-
ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
708-
1 - mask
707+
ablated_feature = input_tensor[start_idx:end_idx] * (1 - mask).to(
708+
input_tensor.dtype
709709
) + (baseline * mask.to(input_tensor.dtype))
710+
ablated_input = ablated_input.to(ablated_feature.dtype)
711+
ablated_input[start_idx:end_idx] = ablated_feature
710712
current_masks.append(torch.stack(tensor_mask, dim=0))
711713
ablated_inputs.append(ablated_input)
712714

tests/attr/test_feature_ablation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,42 @@ def test_multi_input_ablation_with_mask(self) -> None:
220220
perturbations_per_eval=(1, 2, 3),
221221
)
222222

223+
def test_multi_input_ablation_with_int_input_tensor_and_float_baseline(
224+
self,
225+
) -> None:
226+
def sum_forward(*inps: torch.Tensor) -> torch.Tensor:
227+
flattened = [torch.flatten(inp, start_dim=1) for inp in inps]
228+
return torch.cat(flattened, dim=1).sum(1)
229+
230+
ablation_algo = FeatureAblation(sum_forward)
231+
inp1 = torch.tensor([[0, 1], [3, 4]])
232+
inp2 = torch.tensor(
233+
[
234+
[[0.1, 0.2], [0.3, 0.2]],
235+
[[0.4, 0.5], [0.3, 0.2]],
236+
]
237+
)
238+
inp3 = torch.tensor([[0], [1]])
239+
240+
expected = (
241+
torch.tensor([[-0.2, 0.8], [2.8, 3.8]]),
242+
torch.tensor(
243+
[
244+
[[-3.0, -2.9], [-2.8, -2.9]],
245+
[[-2.7, -2.6], [-2.8, -2.9]],
246+
]
247+
),
248+
torch.tensor([[-0.4], [0.6]]),
249+
)
250+
self._ablation_test_assert(
251+
ablation_algo,
252+
(inp1, inp2, inp3),
253+
expected,
254+
target=None,
255+
baselines=(0.2, 3.1, 0.4),
256+
test_enable_cross_tensor_attribution=[False, True],
257+
)
258+
223259
def test_multi_input_ablation_with_mask_weighted(self) -> None:
224260
ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput())
225261
ablation_algo.use_weights = True

0 commit comments

Comments (0)