File tree Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -126,9 +126,14 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
126
126
return math_op (torch .mean (self (module )), torch .mean (other (module )))
127
127
128
128
name = f"Compose({ ', ' .join ([self .__name__ , other .__name__ ])} )"
129
+
130
+ # ToDo: Refine logic for self.target handling
129
131
target = (self .target if isinstance (self .target , list ) else [self .target ]) + (
130
132
other .target if isinstance (other .target , list ) else [other .target ]
131
133
)
134
+
135
+ # Filter out duplicate targets
136
+ target = list (dict .fromkeys (target ))
132
137
else :
133
138
raise TypeError (
134
139
"Can only apply math operations with int, float or Loss. Received type "
@@ -875,6 +880,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
875
880
]
876
881
for target in targets
877
882
]
883
+
884
+ # Filter out duplicate targets
885
+ target = list (dict .fromkeys (target ))
878
886
return CompositeLoss (loss_fn , name = name , target = target )
879
887
880
888
Original file line number Diff line number Diff line change @@ -255,6 +255,7 @@ def collect_activations(
255
255
"""
256
256
if not isinstance (targets , list ):
257
257
targets = [targets ]
258
+ targets = list (dict .fromkeys (targets ))
258
259
catch_activ = ActivationFetcher (model , targets )
259
260
activ_dict = catch_activ (model_input )
260
261
return activ_dict
You can’t perform that action at this time.
0 commit comments