Skip to content

Commit 46e1c12

Browse files
committed
deprecate old retrieval classes
1 parent 8ce55b7 commit 46e1c12

File tree

4 files changed

+89
-3
lines changed

4 files changed

+89
-3
lines changed
Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,40 @@
1+
from elasticsearch.helpers.vectorstore import (
2+
BM25Strategy,
3+
DenseVectorScriptScoreStrategy,
4+
DenseVectorStrategy,
5+
DistanceMetric,
6+
RetrievalStrategy,
7+
SparseVectorStrategy,
8+
)
9+
110
from langchain_elasticsearch.cache import ElasticsearchCache
211
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
312
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
413
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
514
from langchain_elasticsearch.vectorstores import (
615
ApproxRetrievalStrategy,
16+
BM25RetrievalStrategy,
717
ElasticsearchStore,
818
ExactRetrievalStrategy,
919
SparseRetrievalStrategy,
1020
)
1121

1222
__all__ = [
13-
"ApproxRetrievalStrategy",
1423
"ElasticsearchCache",
1524
"ElasticsearchChatMessageHistory",
1625
"ElasticsearchEmbeddings",
1726
"ElasticsearchRetriever",
1827
"ElasticsearchStore",
28+
# retrieval strategies
29+
"BM25Strategy",
30+
"DenseVectorScriptScoreStrategy",
31+
"DenseVectorStrategy",
32+
"DistanceMetric",
33+
"RetrievalStrategy",
34+
"SparseVectorStrategy",
35+
# deprecated retrieval strategies
36+
"ApproxRetrievalStrategy",
37+
"BM25RetrievalStrategy",
1938
"ExactRetrievalStrategy",
2039
"SparseRetrievalStrategy",
2140
]

libs/elasticsearch/langchain_elasticsearch/vectorstores.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from elasticsearch.helpers.vectorstore import (
2525
VectorStore as EVectorStore,
2626
)
27+
from langchain_core._api.deprecation import deprecated
2728
from langchain_core.documents import Document
2829
from langchain_core.embeddings import Embeddings
2930
from langchain_core.vectorstores import VectorStore
@@ -39,6 +40,7 @@
3940
logger = logging.getLogger(__name__)
4041

4142

43+
@deprecated("0.1.4", alternative="RetrievalStrategy", pending=True)
4244
class BaseRetrievalStrategy(ABC):
4345
"""Base class for `Elasticsearch` retrieval strategies."""
4446

@@ -124,6 +126,7 @@ def require_inference(self) -> bool:
124126
return True
125127

126128

129+
@deprecated("0.1.4", alternative="DenseVectorStrategy", pending=True)
127130
class ApproxRetrievalStrategy(BaseRetrievalStrategy):
128131
"""Approximate retrieval strategy using the `HNSW` algorithm."""
129132

@@ -251,6 +254,7 @@ def index(
251254
}
252255

253256

257+
@deprecated("0.1.4", alternative="DenseVectorScriptScoreStrategy", pending=True)
254258
class ExactRetrievalStrategy(BaseRetrievalStrategy):
255259
"""Exact retrieval strategy using the `script_score` query."""
256260

@@ -319,6 +323,7 @@ def index(
319323
}
320324

321325

326+
@deprecated("0.1.4", alternative="SparseVectorStrategy", pending=True)
322327
class SparseRetrievalStrategy(BaseRetrievalStrategy):
323328
"""Sparse retrieval strategy using the `text_expansion` processor."""
324329

@@ -403,6 +408,7 @@ def require_inference(self) -> bool:
403408
return False
404409

405410

411+
@deprecated("0.1.4", alternative="BM25Strategy", pending=True)
406412
class BM25RetrievalStrategy(BaseRetrievalStrategy):
407413
"""Retrieval strategy using the native BM25 algorithm of Elasticsearch."""
408414

@@ -474,9 +480,14 @@ def require_inference(self) -> bool:
474480

475481

476482
def _convert_retrieval_strategy(
477-
langchain_strategy: BaseRetrievalStrategy, distance: DistanceStrategy
483+
langchain_strategy: BaseRetrievalStrategy,
484+
distance: Optional[DistanceStrategy] = None,
478485
) -> RetrievalStrategy:
479486
if isinstance(langchain_strategy, ApproxRetrievalStrategy):
487+
if distance is None:
488+
raise ValueError(
489+
"ApproxRetrievalStrategy requires a distance strategy to be provided."
490+
)
480491
return DenseVectorStrategy(
481492
distance=DistanceMetric[distance],
482493
model_id=langchain_strategy.query_model_id,
@@ -488,6 +499,10 @@ def _convert_retrieval_strategy(
488499
rrf=False if langchain_strategy.rrf is None else langchain_strategy.rrf,
489500
)
490501
elif isinstance(langchain_strategy, ExactRetrievalStrategy):
502+
if distance is None:
503+
raise ValueError(
504+
"ExactRetrievalStrategy requires a distance strategy to be provided."
505+
)
491506
return DenseVectorScriptScoreStrategy(distance=DistanceMetric[distance])
492507
elif isinstance(langchain_strategy, SparseRetrievalStrategy):
493508
return SparseVectorStrategy(langchain_strategy.model_id)

libs/elasticsearch/tests/unit_tests/test_imports.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from langchain_elasticsearch import __all__
22

33
EXPECTED_ALL = [
4-
"ApproxRetrievalStrategy",
54
"ElasticsearchCache",
65
"ElasticsearchChatMessageHistory",
76
"ElasticsearchEmbeddings",
87
"ElasticsearchRetriever",
98
"ElasticsearchStore",
9+
# retrieval strategies
10+
"BM25Strategy",
11+
"DenseVectorScriptScoreStrategy",
12+
"DenseVectorStrategy",
13+
"DistanceMetric",
14+
"RetrievalStrategy",
15+
"SparseVectorStrategy",
16+
# deprecated retrieval strategies
17+
"ApproxRetrievalStrategy",
18+
"BM25RetrievalStrategy",
1019
"ExactRetrievalStrategy",
1120
"SparseRetrievalStrategy",
1221
]

libs/elasticsearch/tests/unit_tests/test_vectorstores.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,17 @@
1111
from langchain_elasticsearch.embeddings import Embeddings, EmbeddingServiceAdapter
1212
from langchain_elasticsearch.vectorstores import (
1313
ApproxRetrievalStrategy,
14+
BM25RetrievalStrategy,
15+
BM25Strategy,
16+
DenseVectorScriptScoreStrategy,
17+
DenseVectorStrategy,
18+
DistanceMetric,
19+
DistanceStrategy,
1420
ElasticsearchStore,
21+
ExactRetrievalStrategy,
22+
SparseRetrievalStrategy,
23+
SparseVectorStrategy,
24+
_convert_retrieval_strategy,
1525
_hits_to_docs_scores,
1626
)
1727

@@ -143,6 +153,39 @@ def test_doc_field_to_metadata(self) -> None:
143153
assert actual == expected
144154

145155

156+
class TestConvertStrategy:
157+
def test_dense_approx(self) -> None:
158+
actual = _convert_retrieval_strategy(
159+
ApproxRetrievalStrategy(query_model_id="my model", hybrid=True, rrf=False),
160+
distance=DistanceStrategy.DOT_PRODUCT,
161+
)
162+
assert isinstance(actual, DenseVectorStrategy)
163+
assert actual.distance == DistanceMetric.DOT_PRODUCT
164+
assert actual.model_id == "my model"
165+
assert actual.hybrid is True
166+
assert actual.rrf is False
167+
168+
def test_dense_exact(self) -> None:
169+
actual = _convert_retrieval_strategy(
170+
ExactRetrievalStrategy(), distance=DistanceStrategy.EUCLIDEAN_DISTANCE
171+
)
172+
assert isinstance(actual, DenseVectorScriptScoreStrategy)
173+
assert actual.distance == DistanceMetric.EUCLIDEAN_DISTANCE
174+
175+
def test_sparse(self) -> None:
176+
actual = _convert_retrieval_strategy(
177+
SparseRetrievalStrategy(model_id="my model ID")
178+
)
179+
assert isinstance(actual, SparseVectorStrategy)
180+
assert actual.model_id == "my model ID"
181+
182+
def test_bm25(self) -> None:
183+
actual = _convert_retrieval_strategy(BM25RetrievalStrategy(k1=1.7, b=5.4))
184+
assert isinstance(actual, BM25Strategy)
185+
assert actual.k1 == 1.7
186+
assert actual.b == 5.4
187+
188+
146189
class TestVectorStore:
147190
@pytest.fixture
148191
def embeddings(self) -> Embeddings:

0 commit comments

Comments
 (0)