Skip to content

Improve the Python type hints for basilisp.lang.multifn #800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Fix a bug with `basilisp.edn/write-string` where nested double quotes were not escaped properly (#1071)
* Fix a bug where additional arguments to `basilisp test` CLI subcommand were not being passed correctly to Pytest (#1075)

### Other
* Improve the state of the Python type hints in `basilisp.lang.multifn` (#800)

## [v0.2.3]
### Added
* Added a compiler metadata flag for suppressing warnings when Var indirection is unavoidable (#1052)
Expand Down
29 changes: 16 additions & 13 deletions src/basilisp/lang/multifn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import threading
from typing import Any, Callable, Generic, Optional, TypeVar
from typing import Any, Callable, Optional, TypeVar

from typing_extensions import Concatenate, Generic, ParamSpec

from basilisp.lang import map as lmap
from basilisp.lang import runtime
Expand All @@ -8,15 +10,16 @@
from basilisp.lang.set import PersistentSet

T = TypeVar("T")
DispatchFunction = Callable[..., T]
Method = Callable[..., Any]
P = ParamSpec("P")
DispatchFunction = Callable[Concatenate[T, P], T]
Method = Callable[Concatenate[T, P], Any]


_GLOBAL_HIERARCHY_SYM = sym.symbol("global-hierarchy", ns=runtime.CORE_NS)
_ISA_SYM = sym.symbol("isa?", ns=runtime.CORE_NS)


class MultiFunction(Generic[T]):
class MultiFunction(Generic[T, P]):
__slots__ = (
"_name",
"_default",
Expand All @@ -33,7 +36,7 @@ class MultiFunction(Generic[T]):
def __init__(
self,
name: sym.Symbol,
dispatch: DispatchFunction,
dispatch: DispatchFunction[T, P],
default: T,
hierarchy: Optional[IRef] = None,
) -> None:
Expand Down Expand Up @@ -63,11 +66,11 @@ def __init__(
# caches.
self._cached_hierarchy = self._hierarchy.deref()

def __call__(self, *args, **kwargs):
key = self._dispatch(*args, **kwargs)
def __call__(self, v: T, *args: P.args, **kwargs: P.kwargs) -> Any:
key = self._dispatch(v, *args, **kwargs)
method = self.get_method(key)
if method is not None:
return method(*args, **kwargs)
return method(v, *args, **kwargs)
raise NotImplementedError

def _reset_cache(self):
Expand All @@ -94,14 +97,14 @@ def _precedes(self, tag: T, parent: T) -> bool:
selection."""
return self._has_preference(tag, parent) or self._is_a(tag, parent)

def add_method(self, key: T, method: Method) -> None:
def add_method(self, key: T, method: Method[T, P]) -> None:
"""Add a new method to this function which will respond for key returned from
the dispatch function."""
with self._lock:
self._methods = self._methods.assoc(key, method)
self._reset_cache()

def _find_and_cache_method(self, key: T) -> Optional[Method]:
def _find_and_cache_method(self, key: T) -> Optional[Method[T, P]]:
"""Find and cache the best method for dispatch value `key`."""
with self._lock:
best_key: Optional[T] = None
Expand All @@ -125,7 +128,7 @@ def _find_and_cache_method(self, key: T) -> Optional[Method]:

return best_method

def get_method(self, key: T) -> Optional[Method]:
def get_method(self, key: T) -> Optional[Method[T, P]]:
"""Return the method which would handle this dispatch key or None if no method
defined for this key and no default."""
if self._cached_hierarchy != self._hierarchy.deref():
Expand Down Expand Up @@ -159,7 +162,7 @@ def prefers(self):
"""Return a mapping of preferred values to the set of other values."""
return self._prefers

def remove_method(self, key: T) -> Optional[Method]:
def remove_method(self, key: T) -> Optional[Method[T, P]]:
"""Remove the method defined for this key and return it."""
with self._lock:
method = self._methods.val_at(key, None)
Expand All @@ -179,5 +182,5 @@ def default(self) -> T:
return self._default

@property
def methods(self) -> IPersistentMap[T, Method]:
def methods(self) -> IPersistentMap[T, Method[T, P]]:
return self._methods