From 1f6c2d56c7e99d74d9b66c50271422d9310e61b9 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Fri, 31 Dec 2021 08:23:09 -0700 Subject: [PATCH 1/4] Improve ModuleOutputsHook, testing coverage, & fix bug * Added the `_remove_all_forward_hooks` function for easy cleanup and removal of hooks without requiring their handles. * Changed `ModuleOutputHook`'s forward hook function name from `forward_hook` to `module_outputs_forward_hook` to allow for easy removal of only hooks using that hook function. * `ModuleOutputHook`'s initialization function now runs the `_remove_all_forward_hooks` function on targets, and only removes the hooks created by `ModuleOutputHook` to avoid breaking PyTorch. * Added the `_count_forward_hooks` function for easy testing of hook creation & removal functionality. * Added tests for verifying that the 'ghost hook' bug has been fixed, and that the new function is working correctly. * Added tests for `ModuleOutputsHook`. Previously we had no tests for this module. --- captum/optim/_core/output_hook.py | 64 ++++- tests/optim/core/test_output_hook.py | 341 ++++++++++++++++++++++++++- 2 files changed, 397 insertions(+), 8 deletions(-) diff --git a/captum/optim/_core/output_hook.py b/captum/optim/_core/output_hook.py index 6cfbc4ff2e..147862a129 100644 --- a/captum/optim/_core/output_hook.py +++ b/captum/optim/_core/output_hook.py @@ -1,5 +1,6 @@ import warnings -from typing import Callable, Iterable, Tuple +from collections import OrderedDict +from typing import Callable, Dict, Iterable, Optional, Tuple from warnings import warn import torch @@ -15,6 +16,9 @@ def __init__(self, target_modules: Iterable[nn.Module]) -> None: target_modules (Iterable of nn.Module): A list of nn.Module targets. """ + for module in target_modules: + # Clean up any old hooks that weren't properly deleted + _remove_all_forward_hooks(module, "module_outputs_forward_hook") self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None) self.hooks = [ module.register_forward_hook(self._forward_hook()) @@ -33,13 +37,13 @@ def is_ready(self) -> bool: def _forward_hook(self) -> Callable: """ - Return the forward_hook function. + Return the module_outputs_forward_hook forward hook function. Returns: - forward_hook (Callable): The forward_hook function. + forward_hook (Callable): The module_outputs_forward_hook function. """ - def forward_hook( + def module_outputs_forward_hook( module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor ) -> None: assert module in self.outputs.keys() @@ -57,7 +61,7 @@ def forward_hook( "that you are passing model layers in your losses." ) - return forward_hook + return module_outputs_forward_hook def consume_outputs(self) -> ModuleOutputMapping: """ @@ -130,3 +134,53 @@ def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping: finally: self.layers.remove_hooks() return activations_dict + + +def _remove_all_forward_hooks( + module: torch.nn.Module, hook_fn_name: Optional[str] = None +) -> None: + """ + This function removes all forward hooks in the specified module, without requiring + any hook handles. This lets us clean up & remove any hooks that weren't property + deleted. + + Warning: Various PyTorch modules and systems make use of hooks, and thus extreme + caution should be exercised when removing all hooks. Users are recommended to give + their hook function a unique name that can be used to safely identify and remove + the target forward hooks. + + Args: + + module (nn.Module): The module instance to remove forward hooks from. + hook_fn_name (str, optional): Optionally only remove specific forward hooks + based on their function's __name__ attribute. + Default: None + """ + + if hook_fn_name is None: + warn("Removing all active hooks will break some PyTorch modules & systems.") + + def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: + if hasattr(module, "_forward_hooks"): + if m._forward_hooks != OrderedDict(): + if name is not None: + dict_items = list(m._forward_hooks.items()) + m._forward_hooks = OrderedDict( + [(i, fn) for i, fn in dict_items if fn.__name__ != name] + ) + else: + m._forward_hooks: Dict[int, Callable] = OrderedDict() + + def _remove_child_hooks( + target_module: torch.nn.Module, hook_name: Optional[str] = None + ) -> None: + for name, child in target_module._modules.items(): + if child is not None: + _remove_hooks(child, hook_name) + _remove_child_hooks(child, hook_name) + + # Remove hooks from target submodules + _remove_child_hooks(module, hook_fn_name) + + # Remove hooks from the target module + _remove_hooks(module, hook_fn_name) diff --git a/tests/optim/core/test_output_hook.py b/tests/optim/core/test_output_hook.py index 075f21d961..f945f73e6a 100644 --- a/tests/optim/core/test_output_hook.py +++ b/tests/optim/core/test_output_hook.py @@ -1,16 +1,205 @@ #!/usr/bin/env python3 import unittest -from typing import cast +from collections import OrderedDict +from typing import List, Optional, Tuple, cast import torch import captum.optim._core.output_hook as output_hook from captum.optim.models import googlenet -from tests.helpers.basic import BaseTest +from tests.helpers.basic import BaseTest, assertTensorAlmostEqual + + +def _count_forward_hooks( + module: torch.nn.Module, hook_fn_name: Optional[str] = None +) -> int: + """ + Count the number of active forward hooks on the specified module instance. + + Args: + + module (nn.Module): The model module instance to count the number of + forward hooks on. + name (str, optional): Optionally only count specific forward hooks based on + their function's __name__ attribute. + Default: None + + Returns: + num_hooks (int): The number of active hooks in the specified module. + """ + + num_hooks: List[int] = [0] + + def _count_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: + if hasattr(m, "_forward_hooks"): + if m._forward_hooks != OrderedDict(): + dict_items = list(m._forward_hooks.items()) + for i, fn in dict_items: + if hook_fn_name is None or fn.__name__ == name: + num_hooks[0] += 1 + + def _count_child_hooks( + target_module: torch.nn.Module, + hook_name: Optional[str] = None, + ) -> None: + + for name, child in target_module._modules.items(): + if child is not None: + _count_hooks(child, hook_name) + _count_child_hooks(child, hook_name) + + _count_child_hooks(module, hook_fn_name) + _count_hooks(module, hook_fn_name) + return num_hooks[0] + + +class TestModuleOutputsHook(BaseTest): + def test_init_single_target(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + target_modules = [model[0]] + + hook_module = output_hook.ModuleOutputsHook(target_modules) + self.assertEqual(len(hook_module.hooks), len(target_modules)) + + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") + self.assertEqual(n_hooks, len(target_modules)) + + outputs = dict.fromkeys(target_modules, None) + self.assertEqual(outputs, hook_module.outputs) + self.assertEqual(list(hook_module.targets), target_modules) + self.assertFalse(hook_module.is_ready) + + def test_init_multiple_targets(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + target_modules = [model[0], model[1]] + + hook_module = output_hook.ModuleOutputsHook(target_modules) + self.assertEqual(len(hook_module.hooks), len(target_modules)) + + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") + self.assertEqual(n_hooks, len(target_modules)) + + outputs = dict.fromkeys(target_modules, None) + self.assertEqual(outputs, hook_module.outputs) + self.assertEqual(list(hook_module.targets), target_modules) + self.assertFalse(hook_module.is_ready) + + def test_init_hook_duplication_fix(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + for i in range(5): + _ = output_hook.ModuleOutputsHook([model[1]]) + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") + self.assertEqual(n_hooks, 1) + + def test_init_multiple_targets_remove_hooks(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + target_modules = [model[0], model[1]] + + hook_module = output_hook.ModuleOutputsHook(target_modules) + + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") + self.assertEqual(n_hooks, len(target_modules)) + + hook_module.remove_hooks() + + n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") + self.assertEqual(n_hooks, 0) + + def test_reset_outputs_multiple_targets(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + target_modules = [model[0], model[1]] + test_input = torch.randn(1, 3, 4, 4) + + hook_module = output_hook.ModuleOutputsHook(target_modules) + self.assertFalse(hook_module.is_ready) + + _ = model(test_input) + + self.assertTrue(hook_module.is_ready) + + outputs_dict = hook_module.outputs + i = 0 + for target, activations in outputs_dict.items(): + self.assertEqual(target, target_modules[i]) + assertTensorAlmostEqual(self, activations, test_input) + i += 1 + + hook_module._reset_outputs() + + self.assertFalse(hook_module.is_ready) + + expected_outputs = dict.fromkeys(target_modules, None) + self.assertEqual(hook_module.outputs, expected_outputs) + + def test_consume_outputs_multiple_targets(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + target_modules = [model[0], model[1]] + test_input = torch.randn(1, 3, 4, 4) + + hook_module = output_hook.ModuleOutputsHook(target_modules) + self.assertFalse(hook_module.is_ready) + + _ = model(test_input) + + self.assertTrue(hook_module.is_ready) + + test_outputs_dict = hook_module.outputs + self.assertIsInstance(test_outputs_dict, dict) + self.assertEqual(len(test_outputs_dict), len(target_modules)) + + i = 0 + for target, activations in test_outputs_dict.items(): + self.assertEqual(target, target_modules[i]) + assertTensorAlmostEqual(self, activations, test_input) + i += 1 + + test_output = hook_module.consume_outputs() + + self.assertFalse(hook_module.is_ready) + + i = 0 + for target, activations in test_output.items(): + self.assertEqual(target, target_modules[i]) + assertTensorAlmostEqual(self, activations, test_input) + i += 1 + + expected_outputs = dict.fromkeys(target_modules, None) + self.assertEqual(hook_module.outputs, expected_outputs) + + def test_consume_outputs_warning(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + target_modules = [model[0], model[1]] + test_input = torch.randn(1, 3, 4, 4) + + hook_module = output_hook.ModuleOutputsHook(target_modules) + self.assertFalse(hook_module.is_ready) + + _ = model(test_input) + + self.assertTrue(hook_module.is_ready) + + hook_module._reset_outputs() + + self.assertFalse(hook_module.is_ready) + + with self.assertWarns(Warning): + _ = hook_module.consume_outputs() class TestActivationFetcher(BaseTest): - def test_activation_fetcher(self) -> None: + def test_activation_fetcher_simple_model(self) -> None: + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + + catch_activ = output_hook.ActivationFetcher(model, targets=[model[0]]) + test_input = torch.randn(1, 3, 224, 224) + activ_out = catch_activ(test_input) + + self.assertIsInstance(activ_out, dict) + self.assertEqual(len(activ_out), 1) + activ = activ_out[model[0]] + assertTensorAlmostEqual(self, activ, test_input) + + def test_activation_fetcher_single_target(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( "Skipping ActivationFetcher test due to insufficient Torch version." @@ -21,5 +210,151 @@ def test_activation_fetcher(self) -> None: activ_out = catch_activ(torch.zeros(1, 3, 224, 224)) self.assertIsInstance(activ_out, dict) + self.assertEqual(len(activ_out), 1) m4d_activ = activ_out[model.mixed4d] self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14]) + + def test_activation_fetcher_multiple_targets(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping ActivationFetcher test due to insufficient Torch version." + ) + model = googlenet(pretrained=True) + + catch_activ = output_hook.ActivationFetcher( + model, targets=[model.mixed4d, model.mixed5b] + ) + activ_out = catch_activ(torch.zeros(1, 3, 224, 224)) + + self.assertIsInstance(activ_out, dict) + self.assertEqual(len(activ_out), 2) + + m4d_activ = activ_out[model.mixed4d] + self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14]) + + m5b_activ = activ_out[model.mixed5b] + self.assertEqual(list(cast(torch.Tensor, m5b_activ).shape), [1, 1024, 7, 7]) + + +class TestRemoveAllForwardHooks(BaseTest): + def test_forward_hook_removal(self) -> None: + def forward_hook_unique_fn( + self, input: Tuple[torch.Tensor], output: torch.Tensor + ) -> None: + pass + + fn_name = forward_hook_unique_fn.__name__ + + layer1 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + layer2 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + model = torch.nn.Sequential(layer1, layer2) + + model.register_forward_hook(forward_hook_unique_fn) + model[1].register_forward_hook(forward_hook_unique_fn) + model[0][1].register_forward_hook(forward_hook_unique_fn) + + n_hooks = _count_forward_hooks(model, fn_name) + self.assertEqual(n_hooks, 3) + + output_hook._remove_all_forward_hooks(model, fn_name) + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 0) + + def test_forward_hook_removal_empty_hook_dicts(self) -> None: + def forward_hook_unique_fn( + self, input: Tuple[torch.Tensor], output: torch.Tensor + ) -> None: + pass + + fn_name = forward_hook_unique_fn.__name__ + + layer1 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + layer2 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + model = torch.nn.Sequential(layer1, layer2) + + model[1].register_forward_hook(forward_hook_unique_fn) + model[0][1].register_forward_hook(forward_hook_unique_fn) + + n_hooks = _count_forward_hooks(model, fn_name) + self.assertEqual(n_hooks, 2) + + output_hook._remove_all_forward_hooks(model, fn_name) + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 0) + + model[1].register_forward_hook(forward_hook_unique_fn) + model[1][1].register_forward_hook(forward_hook_unique_fn) + + n_hooks = _count_forward_hooks(model, fn_name) + self.assertEqual(n_hooks, 2) + + output_hook._remove_all_forward_hooks(model, fn_name) + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 0) + + def test_forward_hook_removal_unique_fn(self) -> None: + def forward_hook_unique_fn_1( + self, input: Tuple[torch.Tensor], output: torch.Tensor + ) -> None: + pass + + def forward_hook_unique_fn_2( + self, input: Tuple[torch.Tensor], output: torch.Tensor + ) -> None: + pass + + fn_name_1 = forward_hook_unique_fn_1.__name__ + fn_name_2 = forward_hook_unique_fn_2.__name__ + + layer1 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + layer2 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + model = torch.nn.Sequential(layer1, layer2) + + model.register_forward_hook(forward_hook_unique_fn_1) + model[1].register_forward_hook(forward_hook_unique_fn_1) + model[0][1].register_forward_hook(forward_hook_unique_fn_1) + + model.register_forward_hook(forward_hook_unique_fn_2) + model[1][0].register_forward_hook(forward_hook_unique_fn_2) + + n_hooks = _count_forward_hooks(model, fn_name_1) + self.assertEqual(n_hooks, 3) + n_hooks = _count_forward_hooks(model, fn_name_2) + self.assertEqual(n_hooks, 2) + + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 5) + + output_hook._remove_all_forward_hooks(model, fn_name_1) + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 2) + + output_hook._remove_all_forward_hooks(model, fn_name_2) + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 0) + + def test_forward_hook_removal_no_hook_fn_name(self) -> None: + def forward_hook_unique_fn( + self, input: Tuple[torch.Tensor], output: torch.Tensor + ) -> None: + pass + + fn_name = forward_hook_unique_fn.__name__ + + model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) + + model[0].register_forward_hook(forward_hook_unique_fn) + model[1].register_forward_hook(forward_hook_unique_fn) + + n_hooks = _count_forward_hooks(model, fn_name) + self.assertEqual(n_hooks, 2) + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 2) + + with self.assertWarns(Warning): + output_hook._remove_all_forward_hooks(model) + + n_hooks = _count_forward_hooks(model, fn_name) + self.assertEqual(n_hooks, 0) + n_hooks = _count_forward_hooks(model) + self.assertEqual(n_hooks, 0) From 66a3aa0ff9902d61076f1dee1d3a2d3c218e0c8c Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Tue, 18 Jan 2022 10:19:09 -0700 Subject: [PATCH 2/4] Make hook fix optional --- captum/optim/__init__.py | 2 ++ captum/optim/_core/output_hook.py | 34 ++++++++++++++++++++++++---- tests/optim/core/test_output_hook.py | 1 + 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py index dc16251393..53fe426ff1 100644 --- a/captum/optim/__init__.py +++ b/captum/optim/__init__.py @@ -3,6 +3,7 @@ from captum.optim import models from captum.optim._core import loss, optimization # noqa: F401 from captum.optim._core.optimization import InputOptimization # noqa: F401 +from captum.optim._core.output_hook import cleanup_module_hooks # noqa: F401 from captum.optim._param.image import images, transforms # noqa: F401 from captum.optim._param.image.images import ImageTensor # noqa: F401 from captum.optim._utils import circuits, reducer # noqa: F401 @@ -27,4 +28,5 @@ "save_tensor_as_image", "show", "weights_to_heatmap_2d", + "cleanup_module_hooks", ] diff --git a/captum/optim/_core/output_hook.py b/captum/optim/_core/output_hook.py index 147862a129..7476df25b0 100644 --- a/captum/optim/_core/output_hook.py +++ b/captum/optim/_core/output_hook.py @@ -1,6 +1,6 @@ import warnings from collections import OrderedDict -from typing import Callable, Dict, Iterable, Optional, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union from warnings import warn import torch @@ -16,9 +16,6 @@ def __init__(self, target_modules: Iterable[nn.Module]) -> None: target_modules (Iterable of nn.Module): A list of nn.Module targets. """ - for module in target_modules: - # Clean up any old hooks that weren't properly deleted - _remove_all_forward_hooks(module, "module_outputs_forward_hook") self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None) self.hooks = [ module.register_forward_hook(self._forward_hook()) @@ -184,3 +181,32 @@ def _remove_child_hooks( # Remove hooks from the target module _remove_hooks(module, hook_fn_name) + + +def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]]) -> None: + """ + Remove any InputOptimization hooks from the specified modules. This may be useful + in the event that something goes wrong in between creating the InputOptimization + instance and running the optimization function, or if InputOptimization fails + without properly removing it's hooks. + + Warning: This function will remove all the hooks placed by InputOptimization + instances on the target modules, and thus can interfere with using multiple + InputOptimization instances. + + Args: + + modules (nn.Module or list of nn.Module): Any module instances that contain + hooks created by InputOptimization, for which the removal of the hooks is + required. + """ + if not hasattr(modules, "__iter__"): + modules = [modules] + # Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions + [ + _remove_all_forward_hooks(module, "module_outputs_forward_hook") + for module in modules + ] + + +__all__ = ["cleanup_module_hooks"] diff --git a/tests/optim/core/test_output_hook.py b/tests/optim/core/test_output_hook.py index f945f73e6a..2897ef96a2 100644 --- a/tests/optim/core/test_output_hook.py +++ b/tests/optim/core/test_output_hook.py @@ -87,6 +87,7 @@ def test_init_multiple_targets(self) -> None: def test_init_hook_duplication_fix(self) -> None: model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) for i in range(5): + output_hook.cleanup_module_hooks(model) _ = output_hook.ModuleOutputsHook([model[1]]) n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") self.assertEqual(n_hooks, 1) From da3d028f928de58146cea1d285bc0b32a704a812 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 6 Apr 2022 13:44:50 -0600 Subject: [PATCH 3/4] Remove hacky hook fix --- captum/optim/__init__.py | 2 - captum/optim/_core/output_hook.py | 84 +---------------- tests/optim/core/test_output_hook.py | 134 +-------------------------- 3 files changed, 3 insertions(+), 217 deletions(-) diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py index 53fe426ff1..dc16251393 100644 --- a/captum/optim/__init__.py +++ b/captum/optim/__init__.py @@ -3,7 +3,6 @@ from captum.optim import models from captum.optim._core import loss, optimization # noqa: F401 from captum.optim._core.optimization import InputOptimization # noqa: F401 -from captum.optim._core.output_hook import cleanup_module_hooks # noqa: F401 from captum.optim._param.image import images, transforms # noqa: F401 from captum.optim._param.image.images import ImageTensor # noqa: F401 from captum.optim._utils import circuits, reducer # noqa: F401 @@ -28,5 +27,4 @@ "save_tensor_as_image", "show", "weights_to_heatmap_2d", - "cleanup_module_hooks", ] diff --git a/captum/optim/_core/output_hook.py b/captum/optim/_core/output_hook.py index 7476df25b0..7425d601fe 100644 --- a/captum/optim/_core/output_hook.py +++ b/captum/optim/_core/output_hook.py @@ -1,10 +1,9 @@ import warnings -from collections import OrderedDict -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union -from warnings import warn +from typing import Callable, Iterable, Tuple import torch import torch.nn as nn +from warnings import warn from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType @@ -131,82 +130,3 @@ def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping: finally: self.layers.remove_hooks() return activations_dict - - -def _remove_all_forward_hooks( - module: torch.nn.Module, hook_fn_name: Optional[str] = None -) -> None: - """ - This function removes all forward hooks in the specified module, without requiring - any hook handles. This lets us clean up & remove any hooks that weren't property - deleted. - - Warning: Various PyTorch modules and systems make use of hooks, and thus extreme - caution should be exercised when removing all hooks. Users are recommended to give - their hook function a unique name that can be used to safely identify and remove - the target forward hooks. - - Args: - - module (nn.Module): The module instance to remove forward hooks from. - hook_fn_name (str, optional): Optionally only remove specific forward hooks - based on their function's __name__ attribute. - Default: None - """ - - if hook_fn_name is None: - warn("Removing all active hooks will break some PyTorch modules & systems.") - - def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: - if hasattr(module, "_forward_hooks"): - if m._forward_hooks != OrderedDict(): - if name is not None: - dict_items = list(m._forward_hooks.items()) - m._forward_hooks = OrderedDict( - [(i, fn) for i, fn in dict_items if fn.__name__ != name] - ) - else: - m._forward_hooks: Dict[int, Callable] = OrderedDict() - - def _remove_child_hooks( - target_module: torch.nn.Module, hook_name: Optional[str] = None - ) -> None: - for name, child in target_module._modules.items(): - if child is not None: - _remove_hooks(child, hook_name) - _remove_child_hooks(child, hook_name) - - # Remove hooks from target submodules - _remove_child_hooks(module, hook_fn_name) - - # Remove hooks from the target module - _remove_hooks(module, hook_fn_name) - - -def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]]) -> None: - """ - Remove any InputOptimization hooks from the specified modules. This may be useful - in the event that something goes wrong in between creating the InputOptimization - instance and running the optimization function, or if InputOptimization fails - without properly removing it's hooks. - - Warning: This function will remove all the hooks placed by InputOptimization - instances on the target modules, and thus can interfere with using multiple - InputOptimization instances. - - Args: - - modules (nn.Module or list of nn.Module): Any module instances that contain - hooks created by InputOptimization, for which the removal of the hooks is - required. - """ - if not hasattr(modules, "__iter__"): - modules = [modules] - # Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions - [ - _remove_all_forward_hooks(module, "module_outputs_forward_hook") - for module in modules - ] - - -__all__ = ["cleanup_module_hooks"] diff --git a/tests/optim/core/test_output_hook.py b/tests/optim/core/test_output_hook.py index 2897ef96a2..1c4b4aaff4 100644 --- a/tests/optim/core/test_output_hook.py +++ b/tests/optim/core/test_output_hook.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import unittest from collections import OrderedDict -from typing import List, Optional, Tuple, cast +from typing import List, Optional, cast import torch @@ -84,14 +84,6 @@ def test_init_multiple_targets(self) -> None: self.assertEqual(list(hook_module.targets), target_modules) self.assertFalse(hook_module.is_ready) - def test_init_hook_duplication_fix(self) -> None: - model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - for i in range(5): - output_hook.cleanup_module_hooks(model) - _ = output_hook.ModuleOutputsHook([model[1]]) - n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook") - self.assertEqual(n_hooks, 1) - def test_init_multiple_targets_remove_hooks(self) -> None: model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) target_modules = [model[0], model[1]] @@ -235,127 +227,3 @@ def test_activation_fetcher_multiple_targets(self) -> None: m5b_activ = activ_out[model.mixed5b] self.assertEqual(list(cast(torch.Tensor, m5b_activ).shape), [1, 1024, 7, 7]) - - -class TestRemoveAllForwardHooks(BaseTest): - def test_forward_hook_removal(self) -> None: - def forward_hook_unique_fn( - self, input: Tuple[torch.Tensor], output: torch.Tensor - ) -> None: - pass - - fn_name = forward_hook_unique_fn.__name__ - - layer1 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - layer2 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - model = torch.nn.Sequential(layer1, layer2) - - model.register_forward_hook(forward_hook_unique_fn) - model[1].register_forward_hook(forward_hook_unique_fn) - model[0][1].register_forward_hook(forward_hook_unique_fn) - - n_hooks = _count_forward_hooks(model, fn_name) - self.assertEqual(n_hooks, 3) - - output_hook._remove_all_forward_hooks(model, fn_name) - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 0) - - def test_forward_hook_removal_empty_hook_dicts(self) -> None: - def forward_hook_unique_fn( - self, input: Tuple[torch.Tensor], output: torch.Tensor - ) -> None: - pass - - fn_name = forward_hook_unique_fn.__name__ - - layer1 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - layer2 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - model = torch.nn.Sequential(layer1, layer2) - - model[1].register_forward_hook(forward_hook_unique_fn) - model[0][1].register_forward_hook(forward_hook_unique_fn) - - n_hooks = _count_forward_hooks(model, fn_name) - self.assertEqual(n_hooks, 2) - - output_hook._remove_all_forward_hooks(model, fn_name) - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 0) - - model[1].register_forward_hook(forward_hook_unique_fn) - model[1][1].register_forward_hook(forward_hook_unique_fn) - - n_hooks = _count_forward_hooks(model, fn_name) - self.assertEqual(n_hooks, 2) - - output_hook._remove_all_forward_hooks(model, fn_name) - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 0) - - def test_forward_hook_removal_unique_fn(self) -> None: - def forward_hook_unique_fn_1( - self, input: Tuple[torch.Tensor], output: torch.Tensor - ) -> None: - pass - - def forward_hook_unique_fn_2( - self, input: Tuple[torch.Tensor], output: torch.Tensor - ) -> None: - pass - - fn_name_1 = forward_hook_unique_fn_1.__name__ - fn_name_2 = forward_hook_unique_fn_2.__name__ - - layer1 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - layer2 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - model = torch.nn.Sequential(layer1, layer2) - - model.register_forward_hook(forward_hook_unique_fn_1) - model[1].register_forward_hook(forward_hook_unique_fn_1) - model[0][1].register_forward_hook(forward_hook_unique_fn_1) - - model.register_forward_hook(forward_hook_unique_fn_2) - model[1][0].register_forward_hook(forward_hook_unique_fn_2) - - n_hooks = _count_forward_hooks(model, fn_name_1) - self.assertEqual(n_hooks, 3) - n_hooks = _count_forward_hooks(model, fn_name_2) - self.assertEqual(n_hooks, 2) - - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 5) - - output_hook._remove_all_forward_hooks(model, fn_name_1) - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 2) - - output_hook._remove_all_forward_hooks(model, fn_name_2) - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 0) - - def test_forward_hook_removal_no_hook_fn_name(self) -> None: - def forward_hook_unique_fn( - self, input: Tuple[torch.Tensor], output: torch.Tensor - ) -> None: - pass - - fn_name = forward_hook_unique_fn.__name__ - - model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity()) - - model[0].register_forward_hook(forward_hook_unique_fn) - model[1].register_forward_hook(forward_hook_unique_fn) - - n_hooks = _count_forward_hooks(model, fn_name) - self.assertEqual(n_hooks, 2) - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 2) - - with self.assertWarns(Warning): - output_hook._remove_all_forward_hooks(model) - - n_hooks = _count_forward_hooks(model, fn_name) - self.assertEqual(n_hooks, 0) - n_hooks = _count_forward_hooks(model) - self.assertEqual(n_hooks, 0) From 3d5c3d5a624b1c322e5a6e13c8c057628a2ec51d Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Fri, 8 Apr 2022 13:21:00 -0600 Subject: [PATCH 4/4] Lint: Fix import order --- captum/optim/_core/output_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/captum/optim/_core/output_hook.py b/captum/optim/_core/output_hook.py index 7425d601fe..160fe5b8a8 100644 --- a/captum/optim/_core/output_hook.py +++ b/captum/optim/_core/output_hook.py @@ -1,9 +1,9 @@ import warnings from typing import Callable, Iterable, Tuple +from warnings import warn import torch import torch.nn as nn -from warnings import warn from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType