From bf8c7d42c5b1af1d56b11e055e08253db82f0bad Mon Sep 17 00:00:00 2001
From: Aobo Yang
Date: Tue, 13 Dec 2022 23:53:47 -0800
Subject: [PATCH 1/3] validate forward output in FeatureAblation

---
 captum/attr/_core/feature_ablation.py | 84 +++++++++++++--------------
 tests/attr/test_feature_ablation.py   | 11 ----
 2 files changed, 41 insertions(+), 54 deletions(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index fab34221ba..7139ac755e 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -53,6 +53,15 @@ def __init__(self, forward_func: Callable) -> None:
         PerturbationAttribution.__init__(self, forward_func)
         self.use_weights = False
 
+        # only used when perturbations_per_eval > 1, where the 1st dim of forward's
+        # output grow as the input batch. If forward's output is aggregated,
+        # we cannot expand the batch size to include more perturbations in one call.
+        # If it's False, we will check with an additional run where
+        # perturbations_per_eval = 1 to see if the output shape is expected;
+        # but it turns to True, we will assume the model's hehavior stays
+        # consistant and no longer check again
+        self._is_output_shape_valid = False
+
     @log_usage()
     def attribute(
         self,
@@ -291,21 +300,10 @@ def attribute(
 
         # flatten eval outputs into 1D (n_outputs)
         # add the leading dim for n_feature_perturbed
-        initial_eval = initial_eval.reshape(1, -1)
-
-        agg_output_mode = FeatureAblation._find_output_mode(
-            perturbations_per_eval, feature_mask
-        )
-
-        if not agg_output_mode:
-            assert n_outputs == num_examples, (
-                "expected output of `forward_func` to have "
-                + "`batch_size` elements for perturbations_per_eval > 1 "
-                + "and all feature_mask.shape[0] > 1"
-            )
+        flattened_initial_eval = initial_eval.reshape(1, -1)
 
         # Initialize attribution totals and counts
-        attrib_type = cast(dtype, initial_eval.dtype)
+        attrib_type = cast(dtype, flattened_initial_eval.dtype)
 
         total_attrib = [
             # attribute w.r.t each output element
@@ -362,21 +360,43 @@ def attribute(
                 if show_progress:
                     attr_progress.update()
 
-                if not agg_output_mode:
-                    # current_batch_size is not n_examples
-                    # it may get expanded by n_feature_perturbed
+                # if perturbations_per_eval > 1, the output shape must grow with
+                # input and not be aggregated
+                if perturbations_per_eval > 1 and not self._is_output_shape_valid:
                     current_batch_size = current_inputs[0].shape[0]
+
+                    # number of perturbations, which is not the same as
+                    # perturbations_per_eval when there are not enough features to perturb
+                    n_perturb = current_batch_size / num_examples
+
+                    current_output_shape = modified_eval.shape
+
+                    # use initial_eval as the forward output of perturbations_per_eval = 1
+                    initial_output_shape = initial_eval.shape
+
                     assert (
-                        modified_eval.numel() == current_batch_size
-                    ), """expected output of forward_func to grow with
-                    batch_size. If this is not the case for your model
-                    please set perturbations_per_eval = 1"""
+                        # check if the output is not a scalar
+                        current_output_shape
+                        and initial_output_shape
+                        # check if the output grows in the same ratio, i.e., not aggregated
+                        and current_output_shape[0]
+                        == n_perturb * initial_output_shape[0]
+                    ), (
+                        "When perturbations_per_eval > 1, forward_func's output "
+                        "should be a tensor whose 1st dim grows with the input "
+                        f"batch size: when input batch size is {num_examples}, "
+                        f"the output shape is {initial_output_shape}; "
+                        f"when input batch size is {current_batch_size}, "
+                        f"the output shape is {current_output_shape}"
+                    )
+
+                    self._is_output_shape_valid = True
 
                 # reshape the leading dim for n_feature_perturbed
                 # flatten each feature's eval outputs into 1D of (n_outputs)
                 modified_eval = modified_eval.reshape(-1, n_outputs)
                 # eval_diff in shape (n_feature_perturbed, n_outputs)
-                eval_diff = initial_eval - modified_eval
+                eval_diff = flattened_initial_eval - modified_eval
 
                 # append the shape of one input example
                 # to make it broadcastable to mask
@@ -572,28 +592,6 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs):
             for inp, mask in zip(inputs, feature_mask)
         )
 
-    @staticmethod
-    def _find_output_mode(
-        perturbations_per_eval: int,
-        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
-    ) -> bool:
-        """
-        Returns True if the output mode is "aggregation output mode"
-
-        Aggregation output mode is defined as: when there is no 1:1 correspondence
-        with the `num_examples` (`batch_size`) and the amount of outputs your model
-        produces, i.e. the model output does not grow in size as the input becomes
-        larger.
-
-        We assume this is the case if `perturbations_per_eval == 1`
-        and your feature mask is None or is associated to all
-        examples in a batch (fm.shape[0] == 1 for all fm in feature_mask).
-        """
-        return perturbations_per_eval == 1 and (
-            feature_mask is None
-            or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
-        )
-
     def _strict_run_forward(self, *args, **kwargs) -> Tensor:
         """
         A temp wrapper for global _run_forward util to force forward output
diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py
index 290b9bf265..91ff63d259 100644
--- a/tests/attr/test_feature_ablation.py
+++ b/tests/attr/test_feature_ablation.py
@@ -345,17 +345,6 @@ def forward_func(inp):
         with self.assertRaises(AssertionError):
             _ = ablation.attribute(inp, perturbations_per_eval=2)
 
-    def test_error_agg_mode_incorrect_fm(self) -> None:
-        def forward_func(inp):
-            return inp[0].unsqueeze(0)
-
-        inp = torch.tensor([[1, 2, 3], [4, 5, 6]])
-        mask = torch.tensor([[0, 1, 2], [0, 0, 1]])
-
-        ablation = FeatureAblation(forward_func)
-        with self.assertRaises(AssertionError):
-            _ = ablation.attribute(inp, perturbations_per_eval=1, feature_mask=mask)
-
     def test_empty_sparse_features(self) -> None:
         ablation_algo = FeatureAblation(BasicModelWithSparseInputs())
         inp1 = torch.tensor([[1.0, -2.0, 3.0], [2.0, -1.0, 3.0]])

From 2f11d84d6fdb7f65ff77b4c4383143ed20c3ed04 Mon Sep 17 00:00:00 2001
From: Aobo Yang
Date: Wed, 14 Dec 2022 10:57:23 -0800
Subject: [PATCH 2/3] typo

---
 captum/attr/_core/feature_ablation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 7139ac755e..deb9a2b22e 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -59,7 +59,7 @@ def __init__(self, forward_func: Callable) -> None:
         # If it's False, we will check with an additional run where
         # perturbations_per_eval = 1 to see if the output shape is expected;
         # but it turns to True, we will assume the model's hehavior stays
-        # consistant and no longer check again
+        # consistent and no longer check again
         self._is_output_shape_valid = False
 
     @log_usage()

From 19379e1f1147146ab028482157afc15cde1bef2b Mon Sep 17 00:00:00 2001
From: Aobo Yang
Date: Mon, 19 Dec 2022 16:12:39 -0800
Subject: [PATCH 3/3] wording

---
 captum/attr/_core/feature_ablation.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index deb9a2b22e..6355b62bd3 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -53,13 +53,14 @@ def __init__(self, forward_func: Callable) -> None:
         PerturbationAttribution.__init__(self, forward_func)
         self.use_weights = False
 
-        # only used when perturbations_per_eval > 1, where the 1st dim of forward's
-        # output grow as the input batch. If forward's output is aggregated,
-        # we cannot expand the batch size to include more perturbations in one call.
-        # If it's False, we will check with an additional run where
-        # perturbations_per_eval = 1 to see if the output shape is expected;
-        # but it turns to True, we will assume the model's hehavior stays
-        # consistent and no longer check again
+        # only used when perturbations_per_eval > 1, where the 1st dim of forward_func's
+        # output must grow as the input batch size. If forward's output is aggregated,
+        # we cannot expand the input to include more perturbations in one call.
+        # If it's False, we will force the validation by comparing the outputs of
+        # the original input and the modified input whose batch size is expanded based on
+        # perturbations_per_eval. Set the flag to True if the output of the modified
+        # input grows as expected. Once it turns to True, we will assume the model's
+        # behavior stays consistent and no longer check again
         self._is_output_shape_valid = False
 
     @log_usage()
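
Note (illustration only, not part of the patches above): a minimal sketch of the forward_func contract that the new assertion enforces, using hypothetical per_example_forward and aggregated_forward functions. With perturbations_per_eval > 1, forward_func's output must be a tensor whose first dimension grows with the expanded input batch; an aggregated (e.g. scalar) output is only expected to work with perturbations_per_eval = 1.

# Illustrative sketch; the forward functions below are hypothetical examples.
import torch
from captum.attr import FeatureAblation

def per_example_forward(inp):
    # one score per example -> output shape (batch_size,), grows with the batch
    return inp.sum(dim=1)

def aggregated_forward(inp):
    # a single scalar for the whole batch -> output does not grow with the batch
    return inp.sum()

inp = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# OK: output's 1st dim grows with the expanded batch, so perturbations can be batched
attr = FeatureAblation(per_example_forward).attribute(inp, perturbations_per_eval=2)

# OK: aggregated output is still supported, but only one perturbation per forward call
attr_agg = FeatureAblation(aggregated_forward).attribute(inp, perturbations_per_eval=1)

# Expected to trip the new output-shape assertion, since the output stays a scalar:
# FeatureAblation(aggregated_forward).attribute(inp, perturbations_per_eval=2)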