Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,22 @@ def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
return int(max(torch.max(mask).item() for mask in feature_mask if mask.numel()))


def _get_feature_idx_to_tensor_idx(
formatted_feature_mask: Tuple[Tensor, ...],
) -> Dict[int, List[int]]:
"""
For a given tuple of tensors, return dict of tensor values to list of tensor
indices they appear in.
"""
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
for i, mask in enumerate(formatted_feature_mask):
for feature_idx in torch.unique(mask):
if feature_idx.item() not in feature_idx_to_tensor_idx:
feature_idx_to_tensor_idx[feature_idx.item()] = []
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
return feature_idx_to_tensor_idx


def _maybe_expand_parameters(
perturbations_per_eval: int,
formatted_inputs: Tuple[Tensor, ...],
Expand Down
13 changes: 2 additions & 11 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_format_additional_forward_args,
_format_feature_mask,
_format_output,
_get_feature_idx_to_tensor_idx,
_is_tuple,
_maybe_expand_parameters,
_run_forward,
Expand Down Expand Up @@ -616,17 +617,7 @@ def _attribute_with_cross_tensor_feature_masks(
def _get_feature_idx_to_tensor_idx(
self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
) -> Dict[int, List[int]]:
"""
For a given tuple of tensors, return dict of tensor values to list of tensor
indices they appear in.
"""
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
for i, mask in enumerate(formatted_feature_mask):
for feature_idx in torch.unique(mask):
if feature_idx.item() not in feature_idx_to_tensor_idx:
feature_idx_to_tensor_idx[feature_idx.item()] = []
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
return feature_idx_to_tensor_idx
return _get_feature_idx_to_tensor_idx(formatted_feature_mask)

def _should_skip_inputs_and_warn(
self,
Expand Down
Loading