Skip to content

Commit 9b1778e

Browse files
committed
canonical names, keyword args only
1 parent f32ceb2 commit 9b1778e

File tree

14 files changed

+282
-215
lines changed

14 files changed

+282
-215
lines changed

elasticsearch/helpers/vectorstore/_async/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ async def model_must_be_deployed(client: AsyncElasticsearch, model_id: str) -> N
3131
pass
3232

3333

34-
async def model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool:
34+
async def model_is_deployed(client: AsyncElasticsearch, model_id: str) -> bool:
3535
try:
36-
await model_must_be_deployed(es_client, model_id)
36+
await model_must_be_deployed(client, model_id)
3737
return True
3838
except NotFoundError:
3939
return False

elasticsearch/helpers/vectorstore/_async/embedding_service.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class AsyncElasticsearchEmbeddings(AsyncEmbeddingService):
5252

5353
def __init__(
5454
self,
55-
es_client: AsyncElasticsearch,
55+
*,
56+
client: AsyncElasticsearch,
5657
model_id: str,
5758
input_field: str = "text_field",
5859
user_agent: str = f"elasticsearch-py-es/{lib_version}",
@@ -63,14 +64,14 @@ def __init__(
6364
:param model_id: The model_id of the model deployed in the Elasticsearch cluster.
6465
:param input_field: The name of the key for the input text field in the
6566
document. Defaults to 'text_field'.
66-
:param es_client: Elasticsearch client connection. Alternatively specify the
67+
:param client: Elasticsearch client connection. Alternatively specify the
6768
Elasticsearch connection with the other es_* parameters.
6869
"""
6970
# Add integration-specific usage header for tracking usage in Elastic Cloud.
7071
# client.options preserves existing (non-user-agent) headers.
71-
es_client = es_client.options(headers={"User-Agent": user_agent})
72+
client = client.options(headers={"User-Agent": user_agent})
7273

73-
self.es_client = es_client
74+
self.client = client
7475
self.model_id = model_id
7576
self.input_field = input_field
7677

@@ -82,7 +83,7 @@ async def embed_query(self, text: str) -> List[float]:
8283
return result[0]
8384

8485
async def _embedding_func(self, texts: List[str]) -> List[List[float]]:
85-
response = await self.es_client.ml.infer_trained_model(
86+
response = await self.client.ml.infer_trained_model(
8687
model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
8788
)
8889
return [doc["predicted_value"] for doc in response["inference_results"]]

elasticsearch/helpers/vectorstore/_async/strategies.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class AsyncRetrievalStrategy(ABC):
2727
@abstractmethod
2828
def es_query(
2929
self,
30+
*,
3031
query: Optional[str],
3132
query_vector: Optional[List[float]],
3233
text_field: str,
@@ -51,6 +52,7 @@ def es_query(
5152
@abstractmethod
5253
def es_mappings_settings(
5354
self,
55+
*,
5456
text_field: str,
5557
vector_field: str,
5658
num_dimensions: Optional[int],
@@ -60,13 +62,15 @@ def es_mappings_settings(
6062
creating inference pipelines or checking if a required model was deployed.
6163
6264
:param client: Elasticsearch client connection.
63-
:param index_name: The name of the Elasticsearch index to create.
64-
:param metadata_mapping: Flat dictionary with field and field type pairs that
65-
describe the schema of the metadata.
65+
:param text_field: The field containing the text data in the index.
66+
:param vector_field: The field containing the vector representations in the index.
67+
:param num_dimensions: If vectors are indexed, how many dimensions do they have.
68+
69+
:return: Dictionary with field and field type pairs that describe the schema.
6670
"""
6771

6872
async def before_index_creation(
69-
self, client: AsyncElasticsearch, text_field: str, vector_field: str
73+
self, *, client: AsyncElasticsearch, text_field: str, vector_field: str
7074
) -> None:
7175
"""
7276
Executes before the index is created. Used for setting up
@@ -101,6 +105,7 @@ def __init__(self, model_id: str = ".elser_model_2"):
101105

102106
def es_query(
103107
self,
108+
*,
104109
query: Optional[str],
105110
query_vector: Optional[List[float]],
106111
text_field: str,
@@ -138,6 +143,7 @@ def es_query(
138143

139144
def es_mappings_settings(
140145
self,
146+
*,
141147
text_field: str,
142148
vector_field: str,
143149
num_dimensions: Optional[int],
@@ -154,7 +160,7 @@ def es_mappings_settings(
154160
return mappings, settings
155161

156162
async def before_index_creation(
157-
self, client: AsyncElasticsearch, text_field: str, vector_field: str
163+
self, *, client: AsyncElasticsearch, text_field: str, vector_field: str
158164
) -> None:
159165
if self.model_id:
160166
await model_must_be_deployed(client, self.model_id)
@@ -183,6 +189,7 @@ class AsyncDenseVectorStrategy(AsyncRetrievalStrategy):
183189

184190
def __init__(
185191
self,
192+
*,
186193
distance: DistanceMetric = DistanceMetric.COSINE,
187194
model_id: Optional[str] = None,
188195
hybrid: bool = False,
@@ -202,6 +209,7 @@ def __init__(
202209

203210
def es_query(
204211
self,
212+
*,
205213
query: Optional[str],
206214
query_vector: Optional[List[float]],
207215
text_field: str,
@@ -236,6 +244,7 @@ def es_query(
236244

237245
def es_mappings_settings(
238246
self,
247+
*,
239248
text_field: str,
240249
vector_field: str,
241250
num_dimensions: Optional[int],
@@ -265,7 +274,7 @@ def es_mappings_settings(
265274
return mappings, {}
266275

267276
async def before_index_creation(
268-
self, client: AsyncElasticsearch, text_field: str, vector_field: str
277+
self, *, client: AsyncElasticsearch, text_field: str, vector_field: str
269278
) -> None:
270279
if self.model_id:
271280
await model_must_be_deployed(client, self.model_id)
@@ -314,6 +323,7 @@ def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None:
314323

315324
def es_query(
316325
self,
326+
*,
317327
query: Optional[str],
318328
query_vector: Optional[List[float]],
319329
text_field: str,
@@ -365,6 +375,7 @@ def es_query(
365375

366376
def es_mappings_settings(
367377
self,
378+
*,
368379
text_field: str,
369380
vector_field: str,
370381
num_dimensions: Optional[int],
@@ -396,6 +407,7 @@ def __init__(
396407

397408
def es_query(
398409
self,
410+
*,
399411
query: Optional[str],
400412
query_vector: Optional[List[float]],
401413
text_field: str,
@@ -423,6 +435,7 @@ def es_query(
423435

424436
def es_mappings_settings(
425437
self,
438+
*,
426439
text_field: str,
427440
vector_field: str,
428441
num_dimensions: Optional[int],

elasticsearch/helpers/vectorstore/_async/vectorstore.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ class AsyncVectorStore:
5050

5151
def __init__(
5252
self,
53-
es_client: AsyncElasticsearch,
54-
index_name: str,
53+
client: AsyncElasticsearch,
54+
*,
55+
index: str,
5556
retrieval_strategy: AsyncRetrievalStrategy,
5657
embedding_service: Optional[AsyncEmbeddingService] = None,
5758
num_dimensions: Optional[int] = None,
@@ -63,26 +64,26 @@ def __init__(
6364
"""
6465
:param user_header: user agent header specific to the 3rd party integration.
6566
Used for usage tracking in Elastic Cloud.
66-
:param index_name: The name of the index to query.
67+
:param index: The name of the index to query.
6768
:param retrieval_strategy: how to index and search the data. See the strategies
6869
module for available strategies.
6970
:param text_field: Name of the field with the textual data.
7071
:param vector_field: For strategies that perform embedding inference in Python,
7172
the embedding vector goes in this field.
72-
:param es_client: Elasticsearch client connection. Alternatively specify the
73+
:param client: Elasticsearch client connection. Alternatively specify the
7374
Elasticsearch connection with the other es_* parameters.
7475
"""
7576
# Add integration-specific usage header for tracking usage in Elastic Cloud.
7677
# client.options preserves existing (non-user-agent) headers.
77-
es_client = es_client.options(headers={"User-Agent": user_agent})
78+
client = client.options(headers={"User-Agent": user_agent})
7879

7980
if hasattr(retrieval_strategy, "text_field"):
8081
retrieval_strategy.text_field = text_field
8182
if hasattr(retrieval_strategy, "vector_field"):
8283
retrieval_strategy.vector_field = vector_field
8384

84-
self.es_client = es_client
85-
self.index_name = index_name
85+
self.client = client
86+
self.index = index
8687
self.retrieval_strategy = retrieval_strategy
8788
self.embedding_service = embedding_service
8889
self.num_dimensions = num_dimensions
@@ -91,11 +92,12 @@ def __init__(
9192
self.metadata_mappings = metadata_mappings
9293

9394
async def close(self) -> None:
94-
return await self.es_client.close()
95+
return await self.client.close()
9596

9697
async def add_texts(
9798
self,
9899
texts: List[str],
100+
*,
99101
metadatas: Optional[List[Dict[str, Any]]] = None,
100102
vectors: Optional[List[List[float]]] = None,
101103
ids: Optional[List[str]] = None,
@@ -136,7 +138,7 @@ async def add_texts(
136138

137139
request: Dict[str, Any] = {
138140
"_op_type": "index",
139-
"_index": self.index_name,
141+
"_index": self.index,
140142
self.text_field: text,
141143
"metadata": metadata,
142144
"_id": ids[i],
@@ -150,7 +152,7 @@ async def add_texts(
150152
if len(requests) > 0:
151153
try:
152154
success, failed = await async_bulk(
153-
self.es_client,
155+
self.client,
154156
requests,
155157
stats_only=True,
156158
refresh=refresh_indices,
@@ -170,6 +172,7 @@ async def add_texts(
170172

171173
async def delete( # type: ignore[no-untyped-def]
172174
self,
175+
*,
173176
ids: Optional[List[str]] = None,
174177
query: Optional[Dict[str, Any]] = None,
175178
refresh_indices: bool = True,
@@ -191,11 +194,11 @@ async def delete( # type: ignore[no-untyped-def]
191194
try:
192195
if ids:
193196
body = [
194-
{"_op_type": "delete", "_index": self.index_name, "_id": _id}
197+
{"_op_type": "delete", "_index": self.index, "_id": _id}
195198
for _id in ids
196199
]
197200
await async_bulk(
198-
self.es_client,
201+
self.client,
199202
body,
200203
refresh=refresh_indices,
201204
ignore_status=404,
@@ -204,8 +207,8 @@ async def delete( # type: ignore[no-untyped-def]
204207
logger.debug(f"Deleted {len(body)} texts from index")
205208

206209
else:
207-
await self.es_client.delete_by_query(
208-
index=self.index_name,
210+
await self.client.delete_by_query(
211+
index=self.index,
209212
query=query,
210213
refresh=refresh_indices,
211214
**delete_kwargs,
@@ -221,6 +224,7 @@ async def delete( # type: ignore[no-untyped-def]
221224

222225
async def search(
223226
self,
227+
*,
224228
query: Optional[str],
225229
query_vector: Optional[List[float]] = None,
226230
k: int = 4,
@@ -270,8 +274,8 @@ async def search(
270274
query_body = custom_query(query_body, query)
271275
logger.debug(f"Calling custom_query, Query body now: {query_body}")
272276

273-
response = await self.es_client.search(
274-
index=self.index_name,
277+
response = await self.client.search(
278+
index=self.index,
275279
**query_body,
276280
size=k,
277281
source=True,
@@ -282,9 +286,9 @@ async def search(
282286
return hits
283287

284288
async def _create_index_if_not_exists(self) -> None:
285-
exists = await self.es_client.indices.exists(index=self.index_name)
289+
exists = await self.client.indices.exists(index=self.index)
286290
if exists.meta.status == 200:
287-
logger.debug(f"Index {self.index_name} already exists. Skipping creation.")
291+
logger.debug(f"Index {self.index} already exists. Skipping creation.")
288292
return
289293

290294
if self.retrieval_strategy.needs_inference():
@@ -312,14 +316,17 @@ async def _create_index_if_not_exists(self) -> None:
312316
mappings["properties"]["metadata"] = {"properties": metadata}
313317

314318
await self.retrieval_strategy.before_index_creation(
315-
self.es_client, self.text_field, self.vector_field
319+
client=self.client,
320+
text_field=self.text_field,
321+
vector_field=self.vector_field,
316322
)
317-
await self.es_client.indices.create(
318-
index=self.index_name, mappings=mappings, settings=settings
323+
await self.client.indices.create(
324+
index=self.index, mappings=mappings, settings=settings
319325
)
320326

321327
async def max_marginal_relevance_search(
322328
self,
329+
*,
323330
embedding_service: AsyncEmbeddingService,
324331
query: str,
325332
vector_field: str,

elasticsearch/helpers/vectorstore/_sync/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None:
3131
pass
3232

3333

34-
def model_is_deployed(es_client: Elasticsearch, model_id: str) -> bool:
34+
def model_is_deployed(client: Elasticsearch, model_id: str) -> bool:
3535
try:
36-
model_must_be_deployed(es_client, model_id)
36+
model_must_be_deployed(client, model_id)
3737
return True
3838
except NotFoundError:
3939
return False

elasticsearch/helpers/vectorstore/_sync/embedding_service.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class ElasticsearchEmbeddings(EmbeddingService):
5252

5353
def __init__(
5454
self,
55-
es_client: Elasticsearch,
55+
*,
56+
client: Elasticsearch,
5657
model_id: str,
5758
input_field: str = "text_field",
5859
user_agent: str = f"elasticsearch-py-es/{lib_version}",
@@ -63,14 +64,14 @@ def __init__(
6364
:param model_id: The model_id of the model deployed in the Elasticsearch cluster.
6465
:param input_field: The name of the key for the input text field in the
6566
document. Defaults to 'text_field'.
66-
:param es_client: Elasticsearch client connection. Alternatively specify the
67+
:param client: Elasticsearch client connection. Alternatively specify the
6768
Elasticsearch connection with the other es_* parameters.
6869
"""
6970
# Add integration-specific usage header for tracking usage in Elastic Cloud.
7071
# client.options preserves existing (non-user-agent) headers.
71-
es_client = es_client.options(headers={"User-Agent": user_agent})
72+
client = client.options(headers={"User-Agent": user_agent})
7273

73-
self.es_client = es_client
74+
self.client = client
7475
self.model_id = model_id
7576
self.input_field = input_field
7677

@@ -82,7 +83,7 @@ def embed_query(self, text: str) -> List[float]:
8283
return result[0]
8384

8485
def _embedding_func(self, texts: List[str]) -> List[List[float]]:
85-
response = self.es_client.ml.infer_trained_model(
86+
response = self.client.ml.infer_trained_model(
8687
model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
8788
)
8889
return [doc["predicted_value"] for doc in response["inference_results"]]

0 commit comments

Comments
 (0)