Skip to content

Commit c50c50b

Browse files
authored
Optim-wip: Fix duplicated target bug (#919)
* Fix duplicated target bug * Fix duplicated target bug in `sum_loss_list` & `collect_activations` * Add ToDo comment for target handling
1 parent d272be7 commit c50c50b

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

captum/optim/_core/loss.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,14 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
126126
return math_op(torch.mean(self(module)), torch.mean(other(module)))
127127

128128
name = f"Compose({', '.join([self.__name__, other.__name__])})"
129+
130+
# ToDo: Refine logic for self.target handling
129131
target = (self.target if isinstance(self.target, list) else [self.target]) + (
130132
other.target if isinstance(other.target, list) else [other.target]
131133
)
134+
135+
# Filter out duplicate targets
136+
target = list(dict.fromkeys(target))
132137
else:
133138
raise TypeError(
134139
"Can only apply math operations with int, float or Loss. Received type "
@@ -875,6 +880,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
875880
]
876881
for target in targets
877882
]
883+
884+
# Filter out duplicate targets
885+
target = list(dict.fromkeys(target))
878886
return CompositeLoss(loss_fn, name=name, target=target)
879887

880888

captum/optim/models/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def collect_activations(
255255
"""
256256
if not isinstance(targets, list):
257257
targets = [targets]
258+
targets = list(dict.fromkeys(targets))
258259
catch_activ = ActivationFetcher(model, targets)
259260
activ_dict = catch_activ(model_input)
260261
return activ_dict

0 commit comments

Comments
 (0)