@@ -220,6 +220,42 @@ def test_multi_input_ablation_with_mask(self) -> None:
220
220
perturbations_per_eval = (1 , 2 , 3 ),
221
221
)
222
222
223
+ def test_multi_input_ablation_with_int_input_tensor_and_float_baseline (
224
+ self ,
225
+ ) -> None :
226
+ def sum_forward (* inps : torch .Tensor ) -> torch .Tensor :
227
+ flattened = [torch .flatten (inp , start_dim = 1 ) for inp in inps ]
228
+ return torch .cat (flattened , dim = 1 ).sum (1 )
229
+
230
+ ablation_algo = FeatureAblation (sum_forward )
231
+ inp1 = torch .tensor ([[0 , 1 ], [3 , 4 ]])
232
+ inp2 = torch .tensor (
233
+ [
234
+ [[0.1 , 0.2 ], [0.3 , 0.2 ]],
235
+ [[0.4 , 0.5 ], [0.3 , 0.2 ]],
236
+ ]
237
+ )
238
+ inp3 = torch .tensor ([[0 ], [1 ]])
239
+
240
+ expected = (
241
+ torch .tensor ([[- 0.2 , 0.8 ], [2.8 , 3.8 ]]),
242
+ torch .tensor (
243
+ [
244
+ [[- 3.0 , - 2.9 ], [- 2.8 , - 2.9 ]],
245
+ [[- 2.7 , - 2.6 ], [- 2.8 , - 2.9 ]],
246
+ ]
247
+ ),
248
+ torch .tensor ([[- 0.4 ], [0.6 ]]),
249
+ )
250
+ self ._ablation_test_assert (
251
+ ablation_algo ,
252
+ (inp1 , inp2 , inp3 ),
253
+ expected ,
254
+ target = None ,
255
+ baselines = (0.2 , 3.1 , 0.4 ),
256
+ test_enable_cross_tensor_attribution = [False , True ],
257
+ )
258
+
223
259
def test_multi_input_ablation_with_mask_weighted (self ) -> None :
224
260
ablation_algo = FeatureAblation (BasicModel_MultiLayer_MultiInput ())
225
261
ablation_algo .use_weights = True
0 commit comments