From 1a8e62c080f45a1caa6848edc7a0fc20860888b2 Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Fri, 17 May 2024 15:04:22 +0200 Subject: [PATCH 1/6] feat: add first type annotations --- elasticsearch_dsl/function.py | 4 ++-- elasticsearch_dsl/utils.py | 6 +++--- noxfile.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/elasticsearch_dsl/function.py b/elasticsearch_dsl/function.py index ef77ce8e..da994993 100644 --- a/elasticsearch_dsl/function.py +++ b/elasticsearch_dsl/function.py @@ -16,7 +16,7 @@ # under the License. import collections.abc -from typing import Dict +from typing import Dict, Optional, ClassVar from .utils import DslBase @@ -70,7 +70,7 @@ class ScoreFunction(DslBase): "filter": {"type": "query"}, "weight": {}, } - name = None + name: ClassVar[Optional[str]] = None def to_dict(self): d = super().to_dict() diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index da6d4fa7..53fe404f 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,7 +18,7 @@ import collections.abc from copy import copy -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Type, ClassVar, Union from typing_extensions import Self @@ -210,7 +210,7 @@ class DslMeta(type): For typical use see `QueryMeta` and `Query` in `elasticsearch_dsl.query`. """ - _types = {} + _types: ClassVar[Dict[str, type["DslBase"]]] = {} def __init__(cls, name, bases, attrs): super().__init__(name, bases, attrs) @@ -251,7 +251,7 @@ class DslBase(metaclass=DslMeta): all values in the `must` attribute into Query objects) """ - _param_defs = {} + _param_defs: ClassVar[Dict[str, Dict[str, Union[str, bool]]]] = {} @classmethod def get_dsl_class( diff --git a/noxfile.py b/noxfile.py index 4ebbe717..f90f22f0 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,6 +30,7 @@ ) TYPED_FILES = ( + "elasticsearch_dsl/function.py", "elasticsearch_dsl/query.py", "tests/test_query.py", ) From f3de1b455b9e30480352c7c270ceacf94f2251e3 Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Fri, 17 May 2024 15:09:09 +0200 Subject: [PATCH 2/6] feat: add _JSONSafeTypes annotation to to_dict methods --- elasticsearch_dsl/function.py | 23 ++++++++++++++--------- elasticsearch_dsl/utils.py | 7 ++++--- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/elasticsearch_dsl/function.py b/elasticsearch_dsl/function.py index da994993..e946b453 100644 --- a/elasticsearch_dsl/function.py +++ b/elasticsearch_dsl/function.py @@ -18,7 +18,7 @@ import collections.abc from typing import Dict, Optional, ClassVar -from .utils import DslBase +from .utils import DslBase, _JSONSafeTypes # Incomplete annotation to not break query.py tests @@ -72,12 +72,14 @@ class ScoreFunction(DslBase): } name: ClassVar[Optional[str]] = None - def to_dict(self): + def to_dict(self) -> Dict[str, _JSONSafeTypes]: d = super().to_dict() # filter and query dicts should be at the same level as us for k in self._param_defs: - if k in d[self.name]: - d[k] = d[self.name].pop(k) + if self.name is not None: + val = d[self.name] + if isinstance(val, dict) and k in val: + d[k] = val.pop(k) return d @@ -88,12 +90,15 @@ class ScriptScore(ScoreFunction): class BoostFactor(ScoreFunction): name = "boost_factor" - def to_dict(self) -> Dict[str, int]: + def to_dict(self) -> Dict[str, _JSONSafeTypes]: d = super().to_dict() - if "value" in d[self.name]: - d[self.name] = d[self.name].pop("value") - else: - del d[self.name] + if self.name is not None: + val = d[self.name] + if isinstance(val, dict): + if "value" in val: + d[self.name] = val.pop("value") + else: + del d[self.name] return d diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 53fe404f..9eea9fc3 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,12 +18,14 @@ import collections.abc from copy import copy -from typing import Any, Dict, Optional, Type, ClassVar, Union +from typing import Any, Dict, Optional, Type, ClassVar, Union, List from typing_extensions import Self from .exceptions import UnknownDslObject, ValidationException +_JSONSafeTypes = Union[int, bool, str, float, List["_JSONSafeTypes"], Dict[str, "_JSONSafeTypes"]] + SKIP_VALUES = ("", None) EXPAND__TO_DOT = True @@ -356,8 +358,7 @@ def __getattr__(self, name): return AttrDict(value) return value - # TODO: This type annotation can probably be made tighter - def to_dict(self) -> Dict[str, Dict[str, Any]]: + def to_dict(self) -> Dict[str, _JSONSafeTypes]: """ Serialize the DSL object to plain dict """ From 631a65cccbf0dce88451ef74f479f041023ce65a Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Fri, 17 May 2024 15:26:49 +0200 Subject: [PATCH 3/6] feat: add typing for SF function --- elasticsearch_dsl/function.py | 22 ++++++++++++---------- elasticsearch_dsl/utils.py | 1 + 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/elasticsearch_dsl/function.py b/elasticsearch_dsl/function.py index e946b453..66b1f9d7 100644 --- a/elasticsearch_dsl/function.py +++ b/elasticsearch_dsl/function.py @@ -16,38 +16,40 @@ # under the License. import collections.abc -from typing import Dict, Optional, ClassVar +from copy import deepcopy +from typing import Dict, Optional, ClassVar, Union, MutableMapping, Any from .utils import DslBase, _JSONSafeTypes -# Incomplete annotation to not break query.py tests -def SF(name_or_sf, **params) -> "ScoreFunction": +def SF(name_or_sf: Union[str, "ScoreFunction", MutableMapping[str, Any]], **params: Any) -> "ScoreFunction": # {"script_score": {"script": "_score"}, "filter": {}} - if isinstance(name_or_sf, collections.abc.Mapping): + if isinstance(name_or_sf, collections.abc.MutableMapping): if params: raise ValueError("SF() cannot accept parameters when passing in a dict.") - kwargs = {} - sf = name_or_sf.copy() + + kwargs: Dict[str, Any] = {} + sf = deepcopy(name_or_sf) for k in ScoreFunction._param_defs: if k in name_or_sf: kwargs[k] = sf.pop(k) # not sf, so just filter+weight, which used to be boost factor + sf_params = params if not sf: name = "boost_factor" # {'FUNCTION': {...}} elif len(sf) == 1: - name, params = sf.popitem() + name, sf_params = sf.popitem() else: raise ValueError(f"SF() got an unexpected fields in the dictionary: {sf!r}") # boost factor special case, see elasticsearch #6343 - if not isinstance(params, collections.abc.Mapping): - params = {"value": params} + if not isinstance(sf_params, collections.abc.Mapping): + sf_params = {"value": sf_params} # mix known params (from _param_defs) and from inside the function - kwargs.update(params) + kwargs.update(sf_params) return ScoreFunction.get_dsl_class(name)(**kwargs) # ScriptScore(script="_score", filter=Q()) diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 9eea9fc3..11cfca65 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -253,6 +253,7 @@ class DslBase(metaclass=DslMeta): all values in the `must` attribute into Query objects) """ + _type_name: ClassVar[str] _param_defs: ClassVar[Dict[str, Dict[str, Union[str, bool]]]] = {} @classmethod From 33a4b5b9ed9a5bf52f0f9eb9d6dc18ec9d9fd5f5 Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Fri, 17 May 2024 23:40:21 +0200 Subject: [PATCH 4/6] chore: fix linting --- elasticsearch_dsl/function.py | 19 +++++++++++++++++-- elasticsearch_dsl/utils.py | 8 +++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/elasticsearch_dsl/function.py b/elasticsearch_dsl/function.py index 66b1f9d7..fc74814c 100644 --- a/elasticsearch_dsl/function.py +++ b/elasticsearch_dsl/function.py @@ -17,12 +17,27 @@ import collections.abc from copy import deepcopy -from typing import Dict, Optional, ClassVar, Union, MutableMapping, Any +from typing import Any, ClassVar, Dict, MutableMapping, Optional, Union, overload from .utils import DslBase, _JSONSafeTypes -def SF(name_or_sf: Union[str, "ScoreFunction", MutableMapping[str, Any]], **params: Any) -> "ScoreFunction": +@overload +def SF(name_or_sf: MutableMapping[str, Any]) -> "ScoreFunction": ... + + +@overload +def SF(name_or_sf: "ScoreFunction") -> "ScoreFunction": ... + + +@overload +def SF(name_or_sf: str, **params: Any) -> "ScoreFunction": ... + + +def SF( + name_or_sf: Union[str, "ScoreFunction", MutableMapping[str, Any]], + **params: Any, +) -> "ScoreFunction": # {"script_score": {"script": "_score"}, "filter": {}} if isinstance(name_or_sf, collections.abc.MutableMapping): if params: diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 11cfca65..797e22b3 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,13 +18,15 @@ import collections.abc from copy import copy -from typing import Any, Dict, Optional, Type, ClassVar, Union, List +from typing import Any, ClassVar, Dict, List, Optional, Type, Union from typing_extensions import Self from .exceptions import UnknownDslObject, ValidationException -_JSONSafeTypes = Union[int, bool, str, float, List["_JSONSafeTypes"], Dict[str, "_JSONSafeTypes"]] +_JSONSafeTypes = Union[ + int, bool, str, float, List["_JSONSafeTypes"], Dict[str, "_JSONSafeTypes"] +] SKIP_VALUES = ("", None) EXPAND__TO_DOT = True @@ -212,7 +214,7 @@ class DslMeta(type): For typical use see `QueryMeta` and `Query` in `elasticsearch_dsl.query`. """ - _types: ClassVar[Dict[str, type["DslBase"]]] = {} + _types: ClassVar[Dict[str, Type["DslBase"]]] = {} def __init__(cls, name, bases, attrs): super().__init__(name, bases, attrs) From 27665566dee84e0d2d56ac05c8859c28c11eee9c Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Tue, 21 May 2024 19:49:41 +0100 Subject: [PATCH 5/6] rename _JSONSafeTypes to JSONType --- elasticsearch_dsl/function.py | 6 +++--- elasticsearch_dsl/utils.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/elasticsearch_dsl/function.py b/elasticsearch_dsl/function.py index fc74814c..635e049b 100644 --- a/elasticsearch_dsl/function.py +++ b/elasticsearch_dsl/function.py @@ -19,7 +19,7 @@ from copy import deepcopy from typing import Any, ClassVar, Dict, MutableMapping, Optional, Union, overload -from .utils import DslBase, _JSONSafeTypes +from .utils import DslBase, JSONType @overload @@ -89,7 +89,7 @@ class ScoreFunction(DslBase): } name: ClassVar[Optional[str]] = None - def to_dict(self) -> Dict[str, _JSONSafeTypes]: + def to_dict(self) -> Dict[str, JSONType]: d = super().to_dict() # filter and query dicts should be at the same level as us for k in self._param_defs: @@ -107,7 +107,7 @@ class ScriptScore(ScoreFunction): class BoostFactor(ScoreFunction): name = "boost_factor" - def to_dict(self) -> Dict[str, _JSONSafeTypes]: + def to_dict(self) -> Dict[str, JSONType]: d = super().to_dict() if self.name is not None: val = d[self.name] diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 797e22b3..1bc81379 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -24,8 +24,8 @@ from .exceptions import UnknownDslObject, ValidationException -_JSONSafeTypes = Union[ - int, bool, str, float, List["_JSONSafeTypes"], Dict[str, "_JSONSafeTypes"] +JSONType = Union[ + int, bool, str, float, List["JSONType"], Dict[str, "JSONType"] ] SKIP_VALUES = ("", None) @@ -361,7 +361,7 @@ def __getattr__(self, name): return AttrDict(value) return value - def to_dict(self) -> Dict[str, _JSONSafeTypes]: + def to_dict(self) -> Dict[str, JSONType]: """ Serialize the DSL object to plain dict """ From 0ef931f2f2f9e7b0e55763b09b2be8e3b97d1850 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Tue, 21 May 2024 19:52:49 +0100 Subject: [PATCH 6/6] format code --- elasticsearch_dsl/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 1bc81379..6e311316 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -24,9 +24,7 @@ from .exceptions import UnknownDslObject, ValidationException -JSONType = Union[ - int, bool, str, float, List["JSONType"], Dict[str, "JSONType"] -] +JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]] SKIP_VALUES = ("", None) EXPAND__TO_DOT = True