Skip to content

Add simple resolver #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/class_resolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class is ``Algorithm`` and it can infer what you mean.
RegistrationError,
RegistrationNameConflict,
RegistrationSynonymConflict,
SimpleResolver,
)
from .func import FunctionResolver
from .utils import (
Expand Down Expand Up @@ -93,6 +94,7 @@ class is ``Algorithm`` and it can infer what you mean.
"Resolver",
"ClassResolver",
"FunctionResolver",
"SimpleResolver",
# Utilities
"get_cls",
"get_subclasses",
Expand Down
53 changes: 53 additions & 0 deletions src/class_resolver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,56 @@ def objective(trial: optuna.Trial) -> float:
"""
key = trial.suggest_categorical(name, sorted(self.lookup_dict))
return self.lookup(key)


class SimpleResolver(BaseResolver[X, X], Generic[X]):
"""
A simple resolver which uses the string representations as key.

While very minimalistic, it can be quite handy when dealing with simple objects, e.g.,

>>> log_level_resolver = SimpleResolver(["debug", "info", "warning", "error"], default="info")
>>> log_level_resolver.make(None)
"info"
>>> r.make("WARNING")
"warning"
>>> r.make("fatal")
Traceback (most recent call last):
...
ValueError: Invalid query=fatal. Possible queries are {"debug", "info", "warning", "error"}.

We can also benefit from, e.g., creation of command-line options for click

>>> log_level_option = log_level_resolver.get_option("--log-level")

Or use the resolver to ensure a type-safe normalization

>>> import typing
>>> LogLevel = typing.Literal["debug", "info", "warning", "error"]
>>> r: SimpleResolver[LogLevel] = SimpleResolver(["debug", "info", "warning", "error"], default="info")
"""

# docstr-coverage: inherited
def extract_name(self, element: X) -> str: # noqa: D102
return str(element)

# docstr-coverage: inherited
def lookup(self, query: Hint[X], default: Optional[X] = None) -> X: # noqa: D102
str_query = self.normalize(str(query))
if str_query in self.lookup_dict:
return self.lookup_dict[str_query]
if query is not None:
raise ValueError(f"Invalid query={query}. Possible queries are {self.options}.")
if default is not None:
return default
if self.default is not None:
return self.default
raise ValueError(
"If query and default are None, a default must be set in the resolver, but it is None, too."
)

# docstr-coverage: inherited
def make(self, query, pos_kwargs: OptionalKwargs = None, **kwargs) -> X: # noqa: D102
if pos_kwargs is not None:
raise ValueError(f"{self.__class__.__name__} does not support positional arguments.")
return self.lookup(query=query, **kwargs)
28 changes: 28 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RegistrationNameConflict,
RegistrationSynonymConflict,
Resolver,
SimpleResolver,
UnexpectedKeywordError,
)

Expand Down Expand Up @@ -490,3 +491,30 @@ class AAlt3Base(Alt3Base):
with self.assertRaises(TypeError) as e:
resolver.make("a")
self.assertEqual("surprise!", str(e.exception))


class TestSimpleResolver(unittest.TestCase):
"""Tests for the simple resolver."""

def setUp(self) -> None:
"""Create test instance."""
self.instance = SimpleResolver([0, 1, 2, 3])

def test_make(self):
"""Test making valid objects."""
for i in range(4):
self.assertEqual(self.instance.make(i), i)
self.assertEqual(self.instance.make(str(i)), i)

def test_make_invalid(self):
"""Test making invalid choices."""
with self.assertRaises(ValueError):
self.instance.make(-1)
with self.assertRaises(ValueError):
self.instance.make(4)

def test_default(self):
"""Test make's interaction with default."""
with self.assertRaises(ValueError):
self.instance.make(None)
self.assertEqual(self.instance.make(None, default=2), 2)