diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index db9066ceba..3f8ab79a22 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -126,9 +126,14 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: return math_op(torch.mean(self(module)), torch.mean(other(module))) name = f"Compose({', '.join([self.__name__, other.__name__])})" + + # ToDo: Refine logic for self.target handling target = (self.target if isinstance(self.target, list) else [self.target]) + ( other.target if isinstance(other.target, list) else [other.target] ) + + # Filter out duplicate targets + target = list(dict.fromkeys(target)) else: raise TypeError( "Can only apply math operations with int, float or Loss. Received type " @@ -720,6 +725,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: ] for target in targets ] + + # Filter out duplicate targets + target = list(dict.fromkeys(target)) return CompositeLoss(loss_fn, name=name, target=target) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index 3de775157b..5e65cf81c6 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -193,6 +193,7 @@ def collect_activations( """ if not isinstance(targets, list): targets = [targets] + targets = list(dict.fromkeys(targets)) catch_activ = ActivationFetcher(model, targets) activ_out = catch_activ(model_input) return activ_out