Skip to content

Commit a44565f

Browse files
sarahtranfbfacebook-github-bot
authored andcommitted
Factor out _get_feature_idx_to_tensor_idx from FeatureAblation (#1645)
Summary: Still need to keep the instance method since it's overriden by occlusion: https://www.internalfb.com/code/fbsource/[8178d370745cf77ea10b992de23e6add428711fd]/fbcode/pytorch/captum/captum/attr/_core/occlusion.py?lines=399 Need to use function later in the stack for within-group SVS. Differential Revision: D82169482
1 parent c9fc87d commit a44565f

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

captum/_utils/common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,22 @@ def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
908908
return int(max(torch.max(mask).item() for mask in feature_mask if mask.numel()))
909909

910910

911+
def _get_feature_idx_to_tensor_idx(
912+
formatted_feature_mask: Tuple[Tensor, ...],
913+
) -> Dict[int, List[int]]:
914+
"""
915+
For a given tuple of tensors, return dict of tensor values to list of tensor
916+
indices they appear in.
917+
"""
918+
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
919+
for i, mask in enumerate(formatted_feature_mask):
920+
for feature_idx in torch.unique(mask):
921+
if feature_idx.item() not in feature_idx_to_tensor_idx:
922+
feature_idx_to_tensor_idx[feature_idx.item()] = []
923+
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
924+
return feature_idx_to_tensor_idx
925+
926+
911927
def _maybe_expand_parameters(
912928
perturbations_per_eval: int,
913929
formatted_inputs: Tuple[Tensor, ...],

captum/attr/_core/feature_ablation.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_format_additional_forward_args,
2525
_format_feature_mask,
2626
_format_output,
27+
_get_feature_idx_to_tensor_idx,
2728
_is_tuple,
2829
_maybe_expand_parameters,
2930
_run_forward,
@@ -616,17 +617,7 @@ def _attribute_with_cross_tensor_feature_masks(
616617
def _get_feature_idx_to_tensor_idx(
617618
self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
618619
) -> Dict[int, List[int]]:
619-
"""
620-
For a given tuple of tensors, return dict of tensor values to list of tensor
621-
indices they appear in.
622-
"""
623-
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
624-
for i, mask in enumerate(formatted_feature_mask):
625-
for feature_idx in torch.unique(mask):
626-
if feature_idx.item() not in feature_idx_to_tensor_idx:
627-
feature_idx_to_tensor_idx[feature_idx.item()] = []
628-
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
629-
return feature_idx_to_tensor_idx
620+
return _get_feature_idx_to_tensor_idx(formatted_feature_mask)
630621

631622
def _should_skip_inputs_and_warn(
632623
self,

0 commit comments

Comments
 (0)