Skip to content

[Backport 8.x] Add type hints to wrappers.py #1846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions elasticsearch_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,25 @@

import collections.abc
from copy import copy
from typing import Any, ClassVar, Dict, List, Optional, Type, Union
from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union

from typing_extensions import Self
from typing_extensions import Self, TypeAlias

from .exceptions import UnknownDslObject, ValidationException

JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]]
# Usefull types

JSONType: TypeAlias = Union[
int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]
]


# Type variables for internals

_KeyT = TypeVar("_KeyT")
_ValT = TypeVar("_ValT")

# Constants

SKIP_VALUES = ("", None)
EXPAND__TO_DOT = True
Expand Down Expand Up @@ -110,18 +122,20 @@ def to_list(self):
return self._l_


class AttrDict:
class AttrDict(Generic[_KeyT, _ValT]):
"""
Helper class to provide attribute like access (read and write) to
dictionaries. Used to provide a convenient way to access both results and
nested dsl dicts.
"""

def __init__(self, d):
_d_: Dict[_KeyT, _ValT]

def __init__(self, d: Dict[_KeyT, _ValT]):
# assign the inner dict manually to prevent __setattr__ from firing
super().__setattr__("_d_", d)

def __contains__(self, key):
def __contains__(self, key: object) -> bool:
return key in self._d_

def __nonzero__(self):
Expand Down
63 changes: 52 additions & 11 deletions elasticsearch_dsl/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,61 @@
# under the License.

import operator
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
Dict,
Literal,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
cast,
)

if TYPE_CHECKING:
from _operator import _SupportsComparison

from typing_extensions import TypeAlias

from .utils import AttrDict

ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"]
RangeValT = TypeVar("RangeValT", bound="_SupportsComparison")

__all__ = ["Range"]


class Range(AttrDict):
OPS = {
class Range(AttrDict[ComparisonOperators, RangeValT]):
OPS: ClassVar[
Mapping[
ComparisonOperators,
Callable[["_SupportsComparison", "_SupportsComparison"], bool],
]
] = {
"lt": operator.lt,
"lte": operator.le,
"gt": operator.gt,
"gte": operator.ge,
}

def __init__(self, *args, **kwargs):
if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)):
def __init__(
self,
d: Optional[Dict[ComparisonOperators, RangeValT]] = None,
/,
**kwargs: RangeValT,
):
if d is not None and (kwargs or not isinstance(d, dict)):
raise ValueError(
"Range accepts a single dictionary or a set of keyword arguments."
)
data = args[0] if args else kwargs

if d is None:
data = cast(Dict[ComparisonOperators, RangeValT], kwargs)
else:
data = d

for k in data:
if k not in self.OPS:
Expand All @@ -47,30 +82,36 @@ def __init__(self, *args, **kwargs):
if "lt" in data and "lte" in data:
raise ValueError("You cannot specify both lt and lte for Range.")

super().__init__(args[0] if args else kwargs)
super().__init__(data)

def __repr__(self):
def __repr__(self) -> str:
return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items())

def __contains__(self, item):
def __contains__(self, item: object) -> bool:
if isinstance(item, str):
return super().__contains__(item)

item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS)
if not item_supports_comp:
return False

for op in self.OPS:
if op in self._d_ and not self.OPS[op](item, self._d_[op]):
if op in self._d_ and not self.OPS[op](
cast("_SupportsComparison", item), self._d_[op]
):
return False
return True

@property
def upper(self):
def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
if "lt" in self._d_:
return self._d_["lt"], False
if "lte" in self._d_:
return self._d_["lte"], True
return None, False

@property
def lower(self):
def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
if "gt" in self._d_:
return self._d_["gt"], False
if "gte" in self._d_:
Expand Down
2 changes: 2 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
TYPED_FILES = (
"elasticsearch_dsl/function.py",
"elasticsearch_dsl/query.py",
"elasticsearch_dsl/wrappers.py",
"tests/test_query.py",
"tests/test_wrappers.py",
)


Expand Down
28 changes: 23 additions & 5 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
# under the License.

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence

if TYPE_CHECKING:
from _operator import _SupportsComparison

import pytest

Expand All @@ -34,7 +38,9 @@
({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()),
],
)
def test_range_contains(kwargs, item):
def test_range_contains(
kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison"
) -> None:
assert item in Range(**kwargs)


Expand All @@ -48,7 +54,9 @@ def test_range_contains(kwargs, item):
({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()),
],
)
def test_range_not_contains(kwargs, item):
def test_range_not_contains(
kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison"
) -> None:
assert item not in Range(**kwargs)


Expand All @@ -62,7 +70,9 @@ def test_range_not_contains(kwargs, item):
((), {"gt": 1, "gte": 1}),
],
)
def test_range_raises_value_error_on_wrong_params(args, kwargs):
def test_range_raises_value_error_on_wrong_params(
args: Sequence[Any], kwargs: Mapping[str, "_SupportsComparison"]
) -> None:
with pytest.raises(ValueError):
Range(*args, **kwargs)

Expand All @@ -76,7 +86,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs):
(Range(lt=42), None, False),
],
)
def test_range_lower(range, lower, inclusive):
def test_range_lower(
range: Range["_SupportsComparison"],
lower: Optional["_SupportsComparison"],
inclusive: bool,
) -> None:
assert (lower, inclusive) == range.lower


Expand All @@ -89,5 +103,9 @@ def test_range_lower(range, lower, inclusive):
(Range(gt=42), None, False),
],
)
def test_range_upper(range, upper, inclusive):
def test_range_upper(
range: Range["_SupportsComparison"],
upper: Optional["_SupportsComparison"],
inclusive: bool,
) -> None:
assert (upper, inclusive) == range.upper
Loading