Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Add hook builder / registry #401

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions classy_vision/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Any, Dict, List

from classy_vision.generic.registry_utils import import_all_modules

Expand Down Expand Up @@ -44,5 +45,66 @@

FILE_ROOT = Path(__file__).parent

HOOK_REGISTRY = {}
HOOK_CLASS_NAMES = set()


def register_hook(name):
"""Registers a :class:`ClassyHook` subclass.

This decorator allows Classy Vision to instantiate a subclass of
:class:`ClassyHook` from a configuration file, even if the class
itself is not part of the base Classy Vision framework. To use it,
apply this decorator to a ClassyHook subclass, like this:

.. code-block:: python

@register_model('resnet')
class CustomHook(ClassyHook):
...

To instantiate a hook from a configuration file, see
:func:`build_model`.
"""

def register_hook_cls(cls):
if name in HOOK_REGISTRY:
raise ValueError("Cannot register duplicate hook ({})".format(name))
if not issubclass(cls, ClassyHook):
raise ValueError(
"Hook ({}: {}) must extend ClassyHook".format(name, cls.__name__)
)
if cls.__name__ in HOOK_CLASS_NAMES:
raise ValueError(
"Cannot register model with duplicate class name ({})".format(
cls.__name__
)
)
HOOK_REGISTRY[name] = cls
HOOK_CLASS_NAMES.add(cls.__name__)
return cls

return register_hook_cls


def build_hooks(hook_configs: List[Dict[str, Any]]):
return [build_hook(config) for config in hook_configs]


def build_hook(hook_config: Dict[str, Any]):
"""Builds a ClassyHook from a config.

This assumes a 'name' key in the config which is used to determine
what model class to instantiate. For instance, a config `{"name":
"my_hook", "foo": "bar"}` will find a class that was registered as
"my_hook" (see :func:`register_hook`) and call .from_config on
it."""
assert hook_config["name"] in HOOK_REGISTRY, (
"Unregistered hook. Did you make sure to use the register_hook decorator "
"AND import the hook file before calling this function??"
)
return HOOK_REGISTRY[hook_config["name"]].from_config(hook_config)


# automatically import any Python files in the hooks/ directory
import_all_modules(FILE_ROOT, "classy_vision.hooks")
9 changes: 8 additions & 1 deletion classy_vision/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class CheckpointHook(ClassyHook):
def __init__(
self,
checkpoint_folder: str,
input_args: Any,
input_args: Any = None,
phase_types: Optional[Collection[str]] = None,
checkpoint_period: int = 1,
) -> None:
Expand Down Expand Up @@ -59,6 +59,13 @@ def __init__(
self.checkpoint_period: int = checkpoint_period
self.phase_counter: int = 0

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CheckpointHook":
assert isinstance(
config["checkpoint_folder"], str
), "checkpoint_folder must be a string specifying the checkpoint directory"
return CheckpointHook(**config)

def _save_checkpoint(self, task, filename):
if getattr(task, "test_only", False):
return
Expand Down
24 changes: 24 additions & 0 deletions test/hooks_checkpoint_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import shutil
import tempfile
Expand All @@ -24,6 +25,29 @@ def setUp(self) -> None:
def tearDown(self) -> None:
shutil.rmtree(self.base_dir)

def test_constructors(self) -> None:
"""
Test that the hooks are constructed correctly.
"""
config = {
"checkpoint_folder": "/test/",
"input_args": {"foo": "bar"},
"phase_types": ["train"],
"checkpoint_period": 2,
}

hook1 = CheckpointHook(**config)
hook2 = CheckpointHook.from_config(config)

self.assertTrue(isinstance(hook1, CheckpointHook))
self.assertTrue(isinstance(hook2, CheckpointHook))

# Verify assert logic works correctly
with self.assertRaises(AssertionError):
bad_config = copy.deepcopy(config)
bad_config["checkpoint_folder"] = 12
CheckpointHook.from_config(bad_config)

def test_state_checkpointing(self) -> None:
"""
Test that the state gets checkpointed without any errors, but only on the
Expand Down
22 changes: 21 additions & 1 deletion test/hooks_classy_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

from classy_vision.hooks import ClassyHook
from classy_vision.hooks import ClassyHook, build_hook, build_hooks, register_hook


@register_hook("test_hook")
class TestHook(ClassyHook):
on_start = ClassyHook._noop
on_phase_start = ClassyHook._noop
Expand All @@ -21,8 +23,26 @@ def __init__(self, a, b):
self.state.a = a
self.state.b = b

@classmethod
def from_config(cls, config):
return TestHook(config["a"], config["b"])


class TestClassyHook(unittest.TestCase):
def test_hook_registry_and_builder(self):
config = {"name": "test_hook", "a": 1, "b": 2}
hook1 = build_hook(hook_config=config)
self.assertTrue(isinstance(hook1, TestHook))
self.assertTrue(hook1.state.a == 1)
self.assertTrue(hook1.state.b == 2)

hook_configs = [copy.deepcopy(config), copy.deepcopy(config)]
hooks = build_hooks(hook_configs=hook_configs)
for hook in hooks:
self.assertTrue(isinstance(hook, TestHook))
self.assertTrue(hook.state.a == 1)
self.assertTrue(hook.state.b == 2)

def test_state_dict(self):
a = 0
b = {1: 2, 3: [4]}
Expand Down