1
1
import threading
2
- from typing import Any , Callable , Generic , Optional , TypeVar
2
+ from typing import Any , Callable , Optional , TypeVar
3
+
4
+ from typing_extensions import Concatenate , Generic , ParamSpec
3
5
4
6
from basilisp .lang import map as lmap
5
7
from basilisp .lang import runtime
8
10
from basilisp .lang .set import PersistentSet
9
11
10
12
T = TypeVar ("T" )
11
- DispatchFunction = Callable [..., T ]
12
- Method = Callable [..., Any ]
13
+ P = ParamSpec ("P" )
14
+ DispatchFunction = Callable [Concatenate [T , P ], T ]
15
+ Method = Callable [Concatenate [T , P ], Any ]
13
16
14
17
15
18
_GLOBAL_HIERARCHY_SYM = sym .symbol ("global-hierarchy" , ns = runtime .CORE_NS )
16
19
_ISA_SYM = sym .symbol ("isa?" , ns = runtime .CORE_NS )
17
20
18
21
19
- class MultiFunction (Generic [T ]):
22
+ class MultiFunction (Generic [T , P ]):
20
23
__slots__ = (
21
24
"_name" ,
22
25
"_default" ,
@@ -33,7 +36,7 @@ class MultiFunction(Generic[T]):
33
36
def __init__ (
34
37
self ,
35
38
name : sym .Symbol ,
36
- dispatch : DispatchFunction ,
39
+ dispatch : DispatchFunction [ T , P ] ,
37
40
default : T ,
38
41
hierarchy : Optional [IRef ] = None ,
39
42
) -> None :
@@ -63,11 +66,11 @@ def __init__(
63
66
# caches.
64
67
self ._cached_hierarchy = self ._hierarchy .deref ()
65
68
66
- def __call__ (self , * args , ** kwargs ) :
67
- key = self ._dispatch (* args , ** kwargs )
69
+ def __call__ (self , v : T , * args : P . args , ** kwargs : P . kwargs ) -> Any :
70
+ key = self ._dispatch (v , * args , ** kwargs )
68
71
method = self .get_method (key )
69
72
if method is not None :
70
- return method (* args , ** kwargs )
73
+ return method (v , * args , ** kwargs )
71
74
raise NotImplementedError
72
75
73
76
def _reset_cache (self ):
@@ -94,14 +97,14 @@ def _precedes(self, tag: T, parent: T) -> bool:
94
97
selection."""
95
98
return self ._has_preference (tag , parent ) or self ._is_a (tag , parent )
96
99
97
- def add_method (self , key : T , method : Method ) -> None :
100
+ def add_method (self , key : T , method : Method [ T , P ] ) -> None :
98
101
"""Add a new method to this function which will respond for key returned from
99
102
the dispatch function."""
100
103
with self ._lock :
101
104
self ._methods = self ._methods .assoc (key , method )
102
105
self ._reset_cache ()
103
106
104
- def _find_and_cache_method (self , key : T ) -> Optional [Method ]:
107
+ def _find_and_cache_method (self , key : T ) -> Optional [Method [ T , P ] ]:
105
108
"""Find and cache the best method for dispatch value `key`."""
106
109
with self ._lock :
107
110
best_key : Optional [T ] = None
@@ -125,7 +128,7 @@ def _find_and_cache_method(self, key: T) -> Optional[Method]:
125
128
126
129
return best_method
127
130
128
- def get_method (self , key : T ) -> Optional [Method ]:
131
+ def get_method (self , key : T ) -> Optional [Method [ T , P ] ]:
129
132
"""Return the method which would handle this dispatch key or None if no method
130
133
defined for this key and no default."""
131
134
if self ._cached_hierarchy != self ._hierarchy .deref ():
@@ -159,7 +162,7 @@ def prefers(self):
159
162
"""Return a mapping of preferred values to the set of other values."""
160
163
return self ._prefers
161
164
162
- def remove_method (self , key : T ) -> Optional [Method ]:
165
+ def remove_method (self , key : T ) -> Optional [Method [ T , P ] ]:
163
166
"""Remove the method defined for this key and return it."""
164
167
with self ._lock :
165
168
method = self ._methods .val_at (key , None )
@@ -179,5 +182,5 @@ def default(self) -> T:
179
182
return self ._default
180
183
181
184
@property
182
- def methods (self ) -> IPersistentMap [T , Method ]:
185
+ def methods (self ) -> IPersistentMap [T , Method [ T , P ] ]:
183
186
return self ._methods
0 commit comments