
Commit ac57d5d

Use orchestration lib (#22)
* Utilize VectorStore from the Elasticsearch client library
* Remove installation of a non-existing dependency group
* Add type ignores for mocked methods
* Bring back the integration_test dependency group
* Bring back ElasticsearchStore.connect_to_elasticsearch
* Make custom_query and doc_builder keyword-only
* Keep the original strategies
* Deprecate the old retrieval classes
* Bring back the previous integration tests
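For reviewers, a minimal sketch of what the switch looks like from the caller's side: retrieval strategies now come from `elasticsearch.helpers.vectorstore` rather than being implemented in this package, and search-time hooks such as `custom_query` (and `doc_builder`) must be passed by keyword. The constructor kwargs (`index_name`, `embedding`, `es_url`, `strategy`) and the `similarity_search` signature are assumed to keep their existing names; this is an illustration, not the package's documented API.

```python
from typing import List

from elasticsearch.helpers.vectorstore import DenseVectorStrategy
from langchain_core.embeddings import Embeddings

from langchain_elasticsearch import ElasticsearchStore


class TinyEmbeddings(Embeddings):
    """Toy embedding model, used here only to keep the sketch self-contained."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), 1.0, 0.0] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 1.0, 0.0]


store = ElasticsearchStore(
    index_name="workplace-docs",         # hypothetical index name
    embedding=TinyEmbeddings(),
    es_url="http://localhost:9200",      # assumed connection kwarg
    strategy=DenseVectorStrategy(),      # strategy class re-exported by this commit
)

# custom_query (and doc_builder) are keyword-only after this change;
# passing them positionally would raise a TypeError.
docs = store.similarity_search(
    "what does the orchestration lib change?",
    k=4,
    custom_query=lambda query_body, query: query_body,  # pass-through hook
)
```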
1 parent 3f96979 commit ac57d5d

13 files changed (+1165, −905 lines)
Lines changed: 20 additions & 1 deletion
```diff
@@ -1,21 +1,40 @@
+from elasticsearch.helpers.vectorstore import (
+    BM25Strategy,
+    DenseVectorScriptScoreStrategy,
+    DenseVectorStrategy,
+    DistanceMetric,
+    RetrievalStrategy,
+    SparseVectorStrategy,
+)
+
 from langchain_elasticsearch.cache import ElasticsearchCache
 from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
 from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
 from langchain_elasticsearch.retrievers import ElasticsearchRetriever
 from langchain_elasticsearch.vectorstores import (
     ApproxRetrievalStrategy,
+    BM25RetrievalStrategy,
     ElasticsearchStore,
     ExactRetrievalStrategy,
     SparseRetrievalStrategy,
 )
 
 __all__ = [
-    "ApproxRetrievalStrategy",
     "ElasticsearchCache",
     "ElasticsearchChatMessageHistory",
     "ElasticsearchEmbeddings",
     "ElasticsearchRetriever",
     "ElasticsearchStore",
+    # retrieval strategies
+    "BM25Strategy",
+    "DenseVectorScriptScoreStrategy",
+    "DenseVectorStrategy",
+    "DistanceMetric",
+    "RetrievalStrategy",
+    "SparseVectorStrategy",
+    # deprecated retrieval strategies
+    "ApproxRetrievalStrategy",
+    "BM25RetrievalStrategy",
     "ExactRetrievalStrategy",
     "SparseRetrievalStrategy",
 ]
```
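After this change the package root re-exports both generations of strategy classes, so downstream code can migrate incrementally. A quick check of what is importable; the old-to-new pairing in the comment is my reading of the class names, not something stated in this diff.

```python
# Both spellings are importable from the package root after this commit.
from langchain_elasticsearch import (
    # new strategies, re-exported from elasticsearch.helpers.vectorstore
    BM25Strategy,
    DenseVectorScriptScoreStrategy,
    DenseVectorStrategy,
    SparseVectorStrategy,
    # deprecated wrappers kept for backwards compatibility
    ApproxRetrievalStrategy,
    BM25RetrievalStrategy,
    ExactRetrievalStrategy,
    SparseRetrievalStrategy,
)

# Presumed old -> new correspondence (inferred from the names):
#   ApproxRetrievalStrategy  -> DenseVectorStrategy
#   ExactRetrievalStrategy   -> DenseVectorScriptScoreStrategy
#   SparseRetrievalStrategy  -> SparseVectorStrategy
#   BM25RetrievalStrategy    -> BM25Strategy
```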
Lines changed: 5 additions & 78 deletions
```diff
@@ -1,12 +1,8 @@
 from enum import Enum
-from typing import List, Union
 
-import numpy as np
 from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError
 from langchain_core import __version__ as langchain_version
 
-Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
-
 
 class DistanceStrategy(str, Enum):
     """Enumerator of the Distance strategies for calculating distances
@@ -19,77 +15,16 @@ class DistanceStrategy(str, Enum):
     COSINE = "COSINE"
 
 
+def user_agent(prefix: str) -> str:
+    return f"{prefix}/{langchain_version}"
+
+
 def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elasticsearch:
     headers = dict(client._headers)
-    headers.update({"user-agent": f"{header_prefix}/{langchain_version}"})
+    headers.update({"user-agent": f"{user_agent(header_prefix)}"})
     return client.options(headers=headers)
 
 
-def maximal_marginal_relevance(
-    query_embedding: np.ndarray,
-    embedding_list: list,
-    lambda_mult: float = 0.5,
-    k: int = 4,
-) -> List[int]:
-    """Calculate maximal marginal relevance."""
-    if min(k, len(embedding_list)) <= 0:
-        return []
-    if query_embedding.ndim == 1:
-        query_embedding = np.expand_dims(query_embedding, axis=0)
-    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
-    most_similar = int(np.argmax(similarity_to_query))
-    idxs = [most_similar]
-    selected = np.array([embedding_list[most_similar]])
-    while len(idxs) < min(k, len(embedding_list)):
-        best_score = -np.inf
-        idx_to_add = -1
-        similarity_to_selected = cosine_similarity(embedding_list, selected)
-        for i, query_score in enumerate(similarity_to_query):
-            if i in idxs:
-                continue
-            redundant_score = max(similarity_to_selected[i])
-            equation_score = (
-                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
-            )
-            if equation_score > best_score:
-                best_score = equation_score
-                idx_to_add = i
-        idxs.append(idx_to_add)
-        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
-    return idxs
-
-
-def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
-    """Row-wise cosine similarity between two equal-width matrices."""
-    if len(X) == 0 or len(Y) == 0:
-        return np.array([])
-
-    X = np.array(X)
-    Y = np.array(Y)
-    if X.shape[1] != Y.shape[1]:
-        raise ValueError(
-            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
-            f"and Y has shape {Y.shape}."
-        )
-    try:
-        import simsimd as simd  # type: ignore
-
-        X = np.array(X, dtype=np.float32)
-        Y = np.array(Y, dtype=np.float32)
-        Z = 1 - simd.cdist(X, Y, metric="cosine")
-        if isinstance(Z, float):
-            return np.array([Z])
-        return np.array(Z)
-    except ImportError:
-        X_norm = np.linalg.norm(X, axis=1)
-        Y_norm = np.linalg.norm(Y, axis=1)
-        # Ignore divide by zero errors run time warnings as those are handled below.
-        with np.errstate(divide="ignore", invalid="ignore"):
-            similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
-            similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
-        return similarity
-
-
 def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None:
     try:
         dummy = {"x": "y"}
@@ -106,11 +41,3 @@ def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None:
     # This error is expected because we do not know the expected document
     # shape and just use a dummy doc above.
     pass
-
-
-def model_is_deployed(es_client: Elasticsearch, model_id: str) -> bool:
-    try:
-        model_must_be_deployed(es_client, model_id)
-        return True
-    except NotFoundError:
-        return False
```
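With `maximal_marginal_relevance`, `cosine_similarity`, and `model_is_deployed` removed (MMR and similarity scoring are presumably handled by the client library's vectorstore helpers now), the remaining helpers only deal with the user-agent header and deployment checks. A small sketch of how the two user-agent helpers compose; the module path `langchain_elasticsearch._utilities` and the `"langchain-py-vs"` prefix are assumptions, since neither appears in this extract.

```python
from elasticsearch import Elasticsearch
from langchain_core import __version__ as langchain_version

# Assumed module path for the helpers shown in the diff above.
from langchain_elasticsearch._utilities import user_agent, with_user_agent_header

client = Elasticsearch("http://localhost:9200")  # any reachable cluster

# user_agent() only formats the header value, e.g. "langchain-py-vs/0.2.0".
assert user_agent("langchain-py-vs") == f"langchain-py-vs/{langchain_version}"

# with_user_agent_header() returns a client whose requests carry that header;
# the original client is untouched because client.options() creates a copy.
instrumented = with_user_agent_header(client, "langchain-py-vs")
```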

libs/elasticsearch/langchain_elasticsearch/embeddings.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, List, Optional
 
 from elasticsearch import Elasticsearch
+from elasticsearch.helpers.vectorstore import EmbeddingService
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import get_from_env
 
@@ -206,3 +207,45 @@ def embed_query(self, text: str) -> List[float]:
             List[float]: The embedding for the input query text.
         """
         return self._embedding_func([text])[0]
+
+
+class EmbeddingServiceAdapter(EmbeddingService):
+    """
+    Adapter for LangChain Embeddings to support the EmbeddingService interface from
+    elasticsearch.helpers.vectorstore.
+    """
+
+    def __init__(self, langchain_embeddings: Embeddings):
+        self._langchain_embeddings = langchain_embeddings
+
+    def __eq__(self, other):  # type: ignore[no-untyped-def]
+        if isinstance(other, self.__class__):
+            return self.__dict__ == other.__dict__
+        else:
+            return False
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate embeddings for a list of documents.
+
+        Args:
+            texts (List[str]): A list of document text strings to generate embeddings
+                for.
+
+        Returns:
+            List[List[float]]: A list of embeddings, one for each document in the input
+                list.
+        """
+        return self._langchain_embeddings.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Generate an embedding for a single query text.
+
+        Args:
+            text (str): The query text to generate an embedding for.
+
+        Returns:
+            List[float]: The embedding for the input query text.
+        """
+        return self._langchain_embeddings.embed_query(text)
```
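The adapter lets any LangChain `Embeddings` implementation be handed to code that expects the client library's `EmbeddingService` interface (such as the vectorstore helpers this commit switches to). A minimal sketch using a toy embedding class; the toy class and its vector values are made up for illustration, only `EmbeddingServiceAdapter` itself comes from the diff above.

```python
from typing import List

from langchain_core.embeddings import Embeddings

from langchain_elasticsearch.embeddings import EmbeddingServiceAdapter


class ToyEmbeddings(Embeddings):
    """Stand-in LangChain embedding model; real projects would plug in OpenAI, HF, etc."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), 0.5] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 0.5]


# Wrap the LangChain object so it satisfies the EmbeddingService interface from
# elasticsearch.helpers.vectorstore; calls are simply forwarded.
toy = ToyEmbeddings()
service = EmbeddingServiceAdapter(toy)

vectors = service.embed_documents(["hello", "world"])
query_vector = service.embed_query("hello world")

# __eq__ compares the wrapped embeddings object, so two adapters around the
# same instance are equal.
assert service == EmbeddingServiceAdapter(toy)
```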
