Skip to content

Commit 896b42c

Browse files
committed
Add subset functionality
`BaseResolver.subresolver()` allows for the resolver to be subsetted based on a list of keys. This might be useful for HPO scenarios.
1 parent bbb9cb3 commit 896b42c

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

src/class_resolver/base.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ def __iter__(self) -> Iterator[X]:
109109
"""Iterate over the registered elements."""
110110
return iter(self.lookup_dict.values())
111111

112+
def subresolver(self, keys: Iterable[str]) -> "BaseResolver[X, Y]":
113+
"""Create a resolver that's a subset of this one."""
114+
elements = [
115+
self.lookup_str(key)
116+
for key in keys
117+
]
118+
return self.__class__(
119+
elements=elements,
120+
default=self.default,
121+
synonyms=self.synonyms,
122+
suffix=self.suffix,
123+
)
124+
112125
@property
113126
def options(self) -> Set[str]:
114127
"""Return the normalized option names."""
@@ -172,6 +185,17 @@ def register(
172185
def lookup(self, query: Hint[X], default: Optional[X] = None) -> X:
173186
"""Lookup an element."""
174187

188+
def lookup_str(self, query: str) -> X:
189+
"""Lookup an element by name."""
190+
key = self.normalize(query)
191+
if key in self.lookup_dict:
192+
return self.lookup_dict[key]
193+
elif key in self.synonyms:
194+
return self.synonyms[key]
195+
else:
196+
valid_choices = sorted(self.options)
197+
raise KeyError(f"{query} is an invalid. Try one of: {valid_choices}")
198+
175199
def docdata(self, query: Hint[X], *path: str, default: Optional[X] = None):
176200
"""Lookup an element and get its docdata.
177201

src/class_resolver/func.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,7 @@ def lookup(self, query: Hint[X], default: Optional[X] = None) -> X:
2929
elif callable(query):
3030
return query # type: ignore
3131
elif isinstance(query, str):
32-
key = self.normalize(query)
33-
if key in self.lookup_dict:
34-
return self.lookup_dict[key]
35-
elif key in self.synonyms:
36-
return self.synonyms[key]
37-
else:
38-
valid_choices = sorted(self.options)
39-
raise KeyError(f"{query} is an invalid. Try one of: {valid_choices}")
32+
return self.lookup_str(query)
4033
else:
4134
raise TypeError(f"Invalid function: {type(query)} - {query}")
4235

0 commit comments

Comments
 (0)