Skip to content

Commit 95b2a59

Browse files
additional typing support
1 parent 8d772e1 commit 95b2a59

File tree

5 files changed

+107
-47
lines changed

5 files changed

+107
-47
lines changed

elasticsearch_dsl/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .aggs import A
2020
from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer
2121
from .document import AsyncDocument, Document
22-
from .document_base import InnerDoc, MetaField
22+
from .document_base import InnerDoc, M, MetaField, mapped_field
2323
from .exceptions import (
2424
ElasticsearchDslException,
2525
IllegalOperation,
@@ -148,6 +148,7 @@
148148
"Keyword",
149149
"Long",
150150
"LongRange",
151+
"M",
151152
"Mapping",
152153
"MetaField",
153154
"MultiSearch",
@@ -178,6 +179,7 @@
178179
"char_filter",
179180
"connections",
180181
"construct_field",
182+
"mapped_field",
181183
"normalizer",
182184
"token_filter",
183185
"tokenizer",

elasticsearch_dsl/document_base.py

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from datetime import date, datetime
1919
from fnmatch import fnmatch
20-
from typing import List, Optional
20+
from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union, overload
2121

2222
from .exceptions import ValidationException
2323
from .field import (
@@ -70,6 +70,34 @@ def __init__(self, name, bases, attrs):
7070
# create the mapping instance
7171
self.mapping = getattr(meta, "mapping", Mapping())
7272

73+
# register the document's fields, which can be given in a few formats:
74+
#
75+
# class MyDocument(Document):
76+
# # required field using native typing
77+
# # (str, int, float, bool, datetime, date)
78+
# field1: str
79+
#
80+
# # optional field using native typing
81+
# field2: Optional[datetime]
82+
#
83+
# # array field using native typing
84+
# field3: list[int]
85+
#
86+
# # sub-object, same as Object(MyInnerDoc)
87+
# field4: MyInnerDoc
88+
#
89+
# # nested sub-objects, same as Nested(MyInnerDoc)
90+
# field5: list[MyInnerDoc]
91+
#
92+
# # use typing, but override with any stock or custom field
93+
# field6: bool = MyCustomField()
94+
#
95+
# # best mypy and pyright typing support
96+
# field7: M[date]
97+
# field8: M[str] = mapped_field(MyCustomText())
98+
#
99+
# # legacy format without Python typing
100+
# field9 = Text()
73101
annotations = attrs.get("__annotations__", {})
74102
fields = set([n for n in attrs if isinstance(attrs[n], Field)])
75103
fields.update(annotations.keys())
@@ -81,25 +109,34 @@ def __init__(self, name, bases, attrs):
81109
required = True
82110
multi = False
83111
while hasattr(type_, "__origin__"):
84-
if type_.__origin__ == Optional:
85-
required = False
112+
if type_.__origin__ == Mapped:
86113
type_ = type_.__args__[0]
87-
elif issubclass(type_.__origin__, List):
114+
elif type_.__origin__ == Union:
115+
if len(type_.__args__) == 2 and type_.__args__[1] is type(None):
116+
required = False
117+
type_ = type_.__args__[0]
118+
else:
119+
raise TypeError("Unsupported union")
120+
elif type_.__origin__ in [list, List]:
88121
multi = True
89122
type_ = type_.__args__[0]
90-
if issubclass(type_, InnerDoc):
123+
else:
124+
break
125+
field_args = []
126+
field_kwargs = {}
127+
if not isinstance(type_, type):
128+
raise TypeError(f"Cannot map type {type_}")
129+
elif issubclass(type_, InnerDoc):
91130
field = Nested if multi else Object
92-
field_args = {}
131+
field_args = [type_]
93132
elif type_ in self.type_annotation_map:
94-
field, field_args = self.type_annotation_map[type_]
133+
field, field_kwargs = self.type_annotation_map[type_]
95134
elif not issubclass(type_, Field):
96135
raise TypeError(f"Cannot map type {type_}")
97136
else:
98137
field = type_
99-
field_args = {}
100-
field_args = {"multi": multi, "required": required, **field_args}
101-
value = field(**field_args)
102-
value._name = name
138+
field_kwargs = {"multi": multi, "required": required, **field_kwargs}
139+
value = field(*field_args, **field_kwargs)
103140
self.mapping.field(name, value)
104141
if name in attrs:
105142
del attrs[name]
@@ -120,6 +157,36 @@ def name(self):
120157
return self.mapping.properties.name
121158

122159

160+
_FieldType = TypeVar("_FieldType")
161+
162+
163+
class Mapped(Generic[_FieldType]):
164+
__slots__ = {}
165+
166+
if TYPE_CHECKING:
167+
168+
@overload
169+
def __get__(self, instance: None, owner: Any) -> InstrumentedField: ...
170+
171+
@overload
172+
def __get__(self, instance: object, owner: Any) -> _FieldType: ...
173+
174+
def __get__(
175+
self, instance: Optional[object], owner: Any
176+
) -> Union[InstrumentedField, _FieldType]: ...
177+
178+
def __set__(self, instance: Optional[object], value: _FieldType) -> None: ...
179+
180+
def __delete__(self, instance: Any) -> None: ...
181+
182+
183+
M = Mapped
184+
185+
186+
def mapped_field(field) -> Any:
187+
return field
188+
189+
123190
class InnerDoc(ObjectBase, metaclass=DocumentMeta):
124191
"""
125192
Common class for inner documents like Object or Nested

elasticsearch_dsl/field.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ def __init__(self, multi=False, required=False, *args, **kwargs):
7777
"""
7878
self._multi = multi
7979
self._required = required
80-
self._name = None
81-
self._parent = None
8280
super().__init__(*args, **kwargs)
8381

8482
def __getitem__(self, subfield):

examples/async/vectors.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import asyncio
4848
import json
4949
import os
50+
from datetime import datetime
51+
from typing import List, Optional
5052
from urllib.request import urlopen
5153

5254
import nltk
@@ -55,12 +57,10 @@
5557

5658
from elasticsearch_dsl import (
5759
AsyncDocument,
58-
Date,
5960
DenseVector,
6061
InnerDoc,
6162
Keyword,
62-
Nested,
63-
Text,
63+
M,
6464
async_connections,
6565
)
6666

@@ -72,22 +72,22 @@
7272

7373

7474
class Passage(InnerDoc):
75-
content = Text()
76-
embedding = DenseVector()
75+
content: M[str]
76+
embedding: M[DenseVector]
7777

7878

7979
class WorkplaceDoc(AsyncDocument):
8080
class Index:
8181
name = "workplace_documents"
8282

83-
name = Text()
84-
summary = Text()
85-
content = Text()
86-
created = Date()
87-
updated = Date()
88-
url = Keyword()
89-
category = Keyword()
90-
passages = Nested(Passage)
83+
name: M[str]
84+
summary: M[str]
85+
content: M[str]
86+
created: M[datetime]
87+
updated: M[Optional[datetime]]
88+
url: M[Keyword]
89+
category: M[Keyword]
90+
passages: M[Optional[List[Passage]]]
9191

9292
_model = None
9393

examples/vectors.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,15 @@
4646
import argparse
4747
import json
4848
import os
49+
from datetime import datetime
50+
from typing import List, Optional
4951
from urllib.request import urlopen
5052

5153
import nltk
5254
from sentence_transformers import SentenceTransformer
5355
from tqdm import tqdm
5456

55-
from elasticsearch_dsl import (
56-
Date,
57-
DenseVector,
58-
Document,
59-
InnerDoc,
60-
Keyword,
61-
Nested,
62-
Text,
63-
connections,
64-
)
57+
from elasticsearch_dsl import DenseVector, Document, InnerDoc, Keyword, M, connections
6558

6659
DATASET_URL = "https://github.com/raw/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
6760
MODEL_NAME = "all-MiniLM-L6-v2"
@@ -71,22 +64,22 @@
7164

7265

7366
class Passage(InnerDoc):
74-
content = Text()
75-
embedding = DenseVector()
67+
content: M[str]
68+
embedding: M[DenseVector]
7669

7770

7871
class WorkplaceDoc(Document):
7972
class Index:
8073
name = "workplace_documents"
8174

82-
name = Text()
83-
summary = Text()
84-
content = Text()
85-
created = Date()
86-
updated = Date()
87-
url = Keyword()
88-
category = Keyword()
89-
passages = Nested(Passage)
75+
name: M[str]
76+
summary: M[str]
77+
content: M[str]
78+
created: M[datetime]
79+
updated: M[Optional[datetime]]
80+
url: M[Keyword]
81+
category: M[Keyword]
82+
passages: M[Optional[List[Passage]]]
9083

9184
_model = None
9285

0 commit comments

Comments
 (0)