Skip to content

Commit 3074c85

Browse files
VectorStore for GenAI integrations (#2528) (#2540)
* ElasticsearchStore * Update elasticsearch/store/_utilities.py Co-authored-by: Quentin Pradet <[email protected]> * rename; depend on client; async only * generate _sync files * add cleanup step for _sync generation * fix formatting * more linting fixes * batch embedding call; infer num_dimensions * revert accidental changes * keep field names only in store; apply metadata mappings in store * fix typos in file names * use `elasticsearch_url` fixture; create conftest.py * export relevant classes * remove Semantic strategy wait for `semantic_text` to land * es_query is sync * async strategies * cleanup old file * add docker-compose service with model deployment * optional dependencies for MMR * only test sync parts * cleanup unasync script * nox: install optional deps * fix tests with requests remembering Transport * fix numpy typing * add user agent default argument * move to `elasticsearch.helpers.vectorstore` * use Protocol over ABC * revert Protocol change because Python 3.7 * address PR feedback: - Strategy suffix - Sphinx docstrings - add user agent to EmbeddingService - raise ConflictError - various cleanup * improve docstring * fix metadata mappings issue * address PR feedback * add error tests for strategies * canonical names, keyword args only * fix sparse vector strategy bug (duplicate `size`) * all wildcard deletes in compose ES --------- Co-authored-by: Quentin Pradet <[email protected]> (cherry picked from commit c2b0ca3) Co-authored-by: Max Jakob <[email protected]>
1 parent 7d544d4 commit 3074c85

24 files changed

+3543
-39
lines changed

dev-requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ twine
1515
build
1616
nox
1717

18-
numpy
1918
pandas
2019
orjson
2120

21+
# mmr for vectorstore
22+
numpy
23+
simsimd
24+
2225
# Testing the 'search_mvt' API response
2326
mapbox-vector-tile
2427
# Python 3.7 gets an old version of mapbox-vector-tile, requiring an
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from elasticsearch.helpers.vectorstore._async.embedding_service import (
19+
AsyncElasticsearchEmbeddings,
20+
AsyncEmbeddingService,
21+
)
22+
from elasticsearch.helpers.vectorstore._async.strategies import (
23+
AsyncBM25Strategy,
24+
AsyncDenseVectorScriptScoreStrategy,
25+
AsyncDenseVectorStrategy,
26+
AsyncRetrievalStrategy,
27+
AsyncSparseVectorStrategy,
28+
)
29+
from elasticsearch.helpers.vectorstore._async.vectorstore import AsyncVectorStore
30+
from elasticsearch.helpers.vectorstore._sync.embedding_service import (
31+
ElasticsearchEmbeddings,
32+
EmbeddingService,
33+
)
34+
from elasticsearch.helpers.vectorstore._sync.strategies import (
35+
BM25Strategy,
36+
DenseVectorScriptScoreStrategy,
37+
DenseVectorStrategy,
38+
RetrievalStrategy,
39+
SparseVectorStrategy,
40+
)
41+
from elasticsearch.helpers.vectorstore._sync.vectorstore import VectorStore
42+
from elasticsearch.helpers.vectorstore._utils import DistanceMetric
43+
44+
__all__ = [
45+
"AsyncBM25Strategy",
46+
"AsyncDenseVectorScriptScoreStrategy",
47+
"AsyncDenseVectorStrategy",
48+
"AsyncElasticsearchEmbeddings",
49+
"AsyncEmbeddingService",
50+
"AsyncRetrievalStrategy",
51+
"AsyncSparseVectorStrategy",
52+
"AsyncVectorStore",
53+
"BM25Strategy",
54+
"DenseVectorScriptScoreStrategy",
55+
"DenseVectorStrategy",
56+
"DistanceMetric",
57+
"ElasticsearchEmbeddings",
58+
"EmbeddingService",
59+
"RetrievalStrategy",
60+
"SparseVectorStrategy",
61+
"VectorStore",
62+
]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from elasticsearch import AsyncElasticsearch, BadRequestError, NotFoundError
19+
20+
21+
async def model_must_be_deployed(client: AsyncElasticsearch, model_id: str) -> None:
    """
    Probe the trained model with a throwaway inference request.

    A ``BadRequestError`` raised by the probe is ignored on purpose: it means
    the model is deployed and merely expects a different input field name.

    :param client: Elasticsearch client used to send the inference request.
    :param model_id: ID of the trained model to probe.

    :raises NotFoundError: if the model is neither downloaded nor deployed.
    :raises ConflictError: if the model is downloaded but not yet deployed.
    """
    probe_docs = [{"text_field": f"test if the model '{model_id}' is deployed"}]
    try:
        await client.ml.infer_trained_model(model_id=model_id, docs=probe_docs)
    except BadRequestError:
        # The model is deployed but expects a different input field name.
        return
32+
33+
34+
async def model_is_deployed(client: AsyncElasticsearch, model_id: str) -> bool:
    """
    Report whether the trained model can be reached for inference.

    Delegates to ``model_must_be_deployed`` and translates a ``NotFoundError``
    into ``False``. Only ``NotFoundError`` is caught here; any other exception
    raised by the check propagates to the caller unchanged.

    :param client: Elasticsearch client used to send the inference request.
    :param model_id: ID of the trained model to check.

    :return: ``True`` if the check passed, ``False`` if it raised
        ``NotFoundError``.
    """
    try:
        await model_must_be_deployed(client, model_id)
    except NotFoundError:
        return False
    return True
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from abc import ABC, abstractmethod
19+
from typing import List
20+
21+
from elasticsearch import AsyncElasticsearch
22+
from elasticsearch._version import __versionstr__ as lib_version
23+
24+
25+
class AsyncEmbeddingService(ABC):
    """Abstract interface for services that turn text into embedding vectors.

    Subclasses must implement both coroutines: one for a batch of documents
    and one for a single query string.
    """

    @abstractmethod
    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a list of documents.

        :param texts: A list of document strings to generate embeddings for.

        :return: A list of embeddings, one for each document in the input.
        """

    @abstractmethod
    async def embed_query(self, query: str) -> List[float]:
        """Generate an embedding for a single query text.

        :param query: The query text to generate an embedding for.

        :return: The embedding for the input query text.
        """
43+
44+
45+
class AsyncElasticsearchEmbeddings(AsyncEmbeddingService):
    """Elasticsearch as a service for embedding model inference.

    You need to have an embedding model downloaded and deployed in Elasticsearch:
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    def __init__(
        self,
        *,
        client: AsyncElasticsearch,
        model_id: str,
        input_field: str = "text_field",
        user_agent: str = f"elasticsearch-py-es/{lib_version}",
    ):
        """
        :param client: Elasticsearch client connection used for the inference
            requests.
        :param model_id: The model_id of the model deployed in the Elasticsearch cluster.
        :param input_field: The name of the key for the input text field in the
            document. Defaults to 'text_field'.
        :param user_agent: user agent header specific to the 3rd party integration.
            Used for usage tracking in Elastic Cloud.
        """
        # Add integration-specific usage header for tracking usage in Elastic Cloud.
        # client.options preserves existing (non-user-agent) headers.
        client = client.options(headers={"User-Agent": user_agent})

        self.client = client
        self.model_id = model_id
        self.input_field = input_field

    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a list of documents.

        :param texts: A list of document strings to generate embeddings for.

        :return: A list of embeddings, one for each document in the input.
            An empty input yields an empty list.
        """
        return await self._embedding_func(texts)

    async def embed_query(self, text: str) -> List[float]:
        """Generate an embedding for a single query text.

        :param text: The query text to generate an embedding for.

        :return: The embedding for the input query text.
        """
        result = await self._embedding_func([text])
        return result[0]

    async def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """Run the deployed model on ``texts`` and collect the predicted vectors.

        :param texts: The strings to embed; each one is sent as a document
            keyed by ``self.input_field``.

        :return: The ``predicted_value`` of every inference result, in
            response order.
        """
        # Nothing to embed: skip the round trip instead of sending a request
        # with an empty ``docs`` list.
        if not texts:
            return []

        response = await self.client.ml.infer_trained_model(
            model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
        )
        return [doc["predicted_value"] for doc in response["inference_results"]]

0 commit comments

Comments
 (0)