diff --git a/classy_vision/hooks/__init__.py b/classy_vision/hooks/__init__.py index a06e4cd6fe..62485485e6 100644 --- a/classy_vision/hooks/__init__.py +++ b/classy_vision/hooks/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from pathlib import Path +from typing import Any, Dict, List from classy_vision.generic.registry_utils import import_all_modules @@ -44,5 +45,66 @@ FILE_ROOT = Path(__file__).parent +HOOK_REGISTRY = {} +HOOK_CLASS_NAMES = set() + + +def register_hook(name): + """Registers a :class:`ClassyHook` subclass. + + This decorator allows Classy Vision to instantiate a subclass of + :class:`ClassyHook` from a configuration file, even if the class + itself is not part of the base Classy Vision framework. To use it, + apply this decorator to a ClassyHook subclass, like this: + + .. code-block:: python + + @register_model('resnet') + class CustomHook(ClassyHook): + ... + + To instantiate a hook from a configuration file, see + :func:`build_model`. + """ + + def register_hook_cls(cls): + if name in HOOK_REGISTRY: + raise ValueError("Cannot register duplicate hook ({})".format(name)) + if not issubclass(cls, ClassyHook): + raise ValueError( + "Hook ({}: {}) must extend ClassyHook".format(name, cls.__name__) + ) + if cls.__name__ in HOOK_CLASS_NAMES: + raise ValueError( + "Cannot register model with duplicate class name ({})".format( + cls.__name__ + ) + ) + HOOK_REGISTRY[name] = cls + HOOK_CLASS_NAMES.add(cls.__name__) + return cls + + return register_hook_cls + + +def build_hooks(hook_configs: List[Dict[str, Any]]): + return [build_hook(config) for config in hook_configs] + + +def build_hook(hook_config: Dict[str, Any]): + """Builds a ClassyHook from a config. + + This assumes a 'name' key in the config which is used to determine + what model class to instantiate. For instance, a config `{"name": + "my_hook", "foo": "bar"}` will find a class that was registered as + "my_hook" (see :func:`register_hook`) and call .from_config on + it.""" + assert hook_config["name"] in HOOK_REGISTRY, ( + "Unregistered hook. Did you make sure to use the register_hook decorator " + "AND import the hook file before calling this function??" + ) + return HOOK_REGISTRY[hook_config["name"]].from_config(hook_config) + + # automatically import any Python files in the hooks/ directory import_all_modules(FILE_ROOT, "classy_vision.hooks") diff --git a/classy_vision/hooks/checkpoint_hook.py b/classy_vision/hooks/checkpoint_hook.py index e54541725f..a9a99e434e 100644 --- a/classy_vision/hooks/checkpoint_hook.py +++ b/classy_vision/hooks/checkpoint_hook.py @@ -28,7 +28,7 @@ class CheckpointHook(ClassyHook): def __init__( self, checkpoint_folder: str, - input_args: Any, + input_args: Any = None, phase_types: Optional[Collection[str]] = None, checkpoint_period: int = 1, ) -> None: @@ -59,6 +59,13 @@ def __init__( self.checkpoint_period: int = checkpoint_period self.phase_counter: int = 0 + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CheckpointHook": + assert isinstance( + config["checkpoint_folder"], str + ), "checkpoint_folder must be a string specifying the checkpoint directory" + return CheckpointHook(**config) + def _save_checkpoint(self, task, filename): if getattr(task, "test_only", False): return diff --git a/test/hooks_checkpoint_hook_test.py b/test/hooks_checkpoint_hook_test.py index a612c12682..2360d037d0 100644 --- a/test/hooks_checkpoint_hook_test.py +++ b/test/hooks_checkpoint_hook_test.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import copy import os import shutil import tempfile @@ -24,6 +25,29 @@ def setUp(self) -> None: def tearDown(self) -> None: shutil.rmtree(self.base_dir) + def test_constructors(self) -> None: + """ + Test that the hooks are constructed correctly. + """ + config = { + "checkpoint_folder": "/test/", + "input_args": {"foo": "bar"}, + "phase_types": ["train"], + "checkpoint_period": 2, + } + + hook1 = CheckpointHook(**config) + hook2 = CheckpointHook.from_config(config) + + self.assertTrue(isinstance(hook1, CheckpointHook)) + self.assertTrue(isinstance(hook2, CheckpointHook)) + + # Verify assert logic works correctly + with self.assertRaises(AssertionError): + bad_config = copy.deepcopy(config) + bad_config["checkpoint_folder"] = 12 + CheckpointHook.from_config(bad_config) + def test_state_checkpointing(self) -> None: """ Test that the state gets checkpointed without any errors, but only on the diff --git a/test/hooks_classy_hook_test.py b/test/hooks_classy_hook_test.py index e748fb9742..f1464d84ea 100644 --- a/test/hooks_classy_hook_test.py +++ b/test/hooks_classy_hook_test.py @@ -4,11 +4,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import copy import unittest -from classy_vision.hooks import ClassyHook +from classy_vision.hooks import ClassyHook, build_hook, build_hooks, register_hook +@register_hook("test_hook") class TestHook(ClassyHook): on_start = ClassyHook._noop on_phase_start = ClassyHook._noop @@ -21,8 +23,26 @@ def __init__(self, a, b): self.state.a = a self.state.b = b + @classmethod + def from_config(cls, config): + return TestHook(config["a"], config["b"]) + class TestClassyHook(unittest.TestCase): + def test_hook_registry_and_builder(self): + config = {"name": "test_hook", "a": 1, "b": 2} + hook1 = build_hook(hook_config=config) + self.assertTrue(isinstance(hook1, TestHook)) + self.assertTrue(hook1.state.a == 1) + self.assertTrue(hook1.state.b == 2) + + hook_configs = [copy.deepcopy(config), copy.deepcopy(config)] + hooks = build_hooks(hook_configs=hook_configs) + for hook in hooks: + self.assertTrue(isinstance(hook, TestHook)) + self.assertTrue(hook.state.a == 1) + self.assertTrue(hook.state.b == 2) + def test_state_dict(self): a = 0 b = {1: 2, 3: [4]}