Skip to content

Commit fdb2746

Browse files
committed
Improve the Python type hints for basilisp.lang.multifn
1 parent bb0e011 commit fdb2746

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3333
* Removed support for PyPy 3.8 (#785)
3434

3535
### Other
36-
* Improve the state of the Python type hints in `basilisp.lang.*` (#797, #784)
37-
36+
* Improve the state of the Python type hints in `basilisp.lang.*` (#797, #784, #8??)
3837

3938
## [v0.1.0b0]
4039
### Added

src/basilisp/lang/multifn.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import threading
2-
from typing import Any, Callable, Generic, Optional, TypeVar
2+
from typing import Any, Callable, Optional, TypeVar
3+
4+
from typing_extensions import Concatenate, Generic, ParamSpec
35

46
from basilisp.lang import map as lmap
57
from basilisp.lang import runtime
@@ -8,15 +10,16 @@
810
from basilisp.lang.set import PersistentSet
911

1012
T = TypeVar("T")
11-
DispatchFunction = Callable[..., T]
12-
Method = Callable[..., Any]
13+
P = ParamSpec("P")
14+
DispatchFunction = Callable[Concatenate[T, P], T]
15+
Method = Callable[Concatenate[T, P], Any]
1316

1417

1518
_GLOBAL_HIERARCHY_SYM = sym.symbol("global-hierarchy", ns=runtime.CORE_NS)
1619
_ISA_SYM = sym.symbol("isa?", ns=runtime.CORE_NS)
1720

1821

19-
class MultiFunction(Generic[T]):
22+
class MultiFunction(Generic[T, P]):
2023
__slots__ = (
2124
"_name",
2225
"_default",
@@ -33,7 +36,7 @@ class MultiFunction(Generic[T]):
3336
def __init__(
3437
self,
3538
name: sym.Symbol,
36-
dispatch: DispatchFunction,
39+
dispatch: DispatchFunction[T, P],
3740
default: T,
3841
hierarchy: Optional[IRef] = None,
3942
) -> None:
@@ -63,11 +66,11 @@ def __init__(
6366
# caches.
6467
self._cached_hierarchy = self._hierarchy.deref()
6568

66-
def __call__(self, *args, **kwargs):
67-
key = self._dispatch(*args, **kwargs)
69+
def __call__(self, v: T, *args: P.args, **kwargs: P.kwargs) -> Any:
70+
key = self._dispatch(v, *args, **kwargs)
6871
method = self.get_method(key)
6972
if method is not None:
70-
return method(*args, **kwargs)
73+
return method(v, *args, **kwargs)
7174
raise NotImplementedError
7275

7376
def _reset_cache(self):
@@ -94,14 +97,14 @@ def _precedes(self, tag: T, parent: T) -> bool:
9497
selection."""
9598
return self._has_preference(tag, parent) or self._is_a(tag, parent)
9699

97-
def add_method(self, key: T, method: Method) -> None:
100+
def add_method(self, key: T, method: Method[T, P]) -> None:
98101
"""Add a new method to this function which will respond for key returned from
99102
the dispatch function."""
100103
with self._lock:
101104
self._methods = self._methods.assoc(key, method)
102105
self._reset_cache()
103106

104-
def _find_and_cache_method(self, key: T) -> Optional[Method]:
107+
def _find_and_cache_method(self, key: T) -> Optional[Method[T, P]]:
105108
"""Find and cache the best method for dispatch value `key`."""
106109
with self._lock:
107110
best_key: Optional[T] = None
@@ -125,7 +128,7 @@ def _find_and_cache_method(self, key: T) -> Optional[Method]:
125128

126129
return best_method
127130

128-
def get_method(self, key: T) -> Optional[Method]:
131+
def get_method(self, key: T) -> Optional[Method[T, P]]:
129132
"""Return the method which would handle this dispatch key or None if no method
130133
defined for this key and no default."""
131134
if self._cached_hierarchy != self._hierarchy.deref():
@@ -159,7 +162,7 @@ def prefers(self):
159162
"""Return a mapping of preferred values to the set of other values."""
160163
return self._prefers
161164

162-
def remove_method(self, key: T) -> Optional[Method]:
165+
def remove_method(self, key: T) -> Optional[Method[T, P]]:
163166
"""Remove the method defined for this key and return it."""
164167
with self._lock:
165168
method = self._methods.val_at(key, None)
@@ -179,5 +182,5 @@ def default(self) -> T:
179182
return self._default
180183

181184
@property
182-
def methods(self) -> IPersistentMap[T, Method]:
185+
def methods(self) -> IPersistentMap[T, Method[T, P]]:
183186
return self._methods

0 commit comments

Comments
 (0)