Skip to content

Commit d397982

Browse files
committed
use elasticsearch_url fixture; create conftest.py
1 parent 7647961 commit d397982

File tree

13 files changed

+211
-1217
lines changed

13 files changed

+211
-1217
lines changed

elasticsearch/vectorstore/_async/strategies.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from abc import ABC, abstractmethod
1919
from enum import Enum
20-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast
20+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
2121

2222
from elasticsearch import AsyncElasticsearch
2323
from elasticsearch.vectorstore._async._utils import model_must_be_deployed
@@ -250,7 +250,6 @@ class DenseVector(RetrievalStrategy):
250250

251251
def __init__(
252252
self,
253-
knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw",
254253
distance: DistanceMetric = DistanceMetric.COSINE,
255254
model_id: Optional[str] = None,
256255
hybrid: bool = False,
@@ -262,7 +261,6 @@ def __init__(
262261
"to enable hybrid you have to specify a text_field (for BM25 matching)"
263262
)
264263

265-
self.knn_type = knn_type
266264
self.distance = distance
267265
self.model_id = model_id
268266
self.hybrid = hybrid

elasticsearch/vectorstore/_sync/strategies.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from abc import ABC, abstractmethod
1919
from enum import Enum
20-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast
20+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
2121

2222
from elasticsearch import Elasticsearch
2323
from elasticsearch.vectorstore._sync._utils import model_must_be_deployed
@@ -250,7 +250,6 @@ class DenseVector(RetrievalStrategy):
250250

251251
def __init__(
252252
self,
253-
knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw",
254253
distance: DistanceMetric = DistanceMetric.COSINE,
255254
model_id: Optional[str] = None,
256255
hybrid: bool = False,
@@ -262,7 +261,6 @@ def __init__(
262261
"to enable hybrid you have to specify a text_field (for BM25 matching)"
263262
)
264263

265-
self.knn_type = knn_type
266264
self.distance = distance
267265
self.model_id = model_id
268266
self.hybrid = hybrid

elasticsearch/vectorstore/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424

2525
def maximal_marginal_relevance(
26-
query_embedding: list[float],
27-
embedding_list: list[list[float]],
26+
query_embedding: List[float],
27+
embedding_list: List[List[float]],
2828
lambda_mult: float = 0.5,
2929
k: int = 4,
3030
) -> List[int]:

setup.py

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -94,77 +94,3 @@
9494
"orjson": ["orjson>=3"],
9595
},
9696
)
97-
98-
vectorstore_package_name = "elasticsearch[vectorstore]"
99-
base_dir = abspath(dirname(__file__))
100-
101-
with open(join(base_dir, package_name, "_version.py")) as f:
102-
package_version = re.search(
103-
r"__versionstr__\s+=\s+[\"\']([^\"\']+)[\"\']", f.read()
104-
).group(1)
105-
106-
with open(join(base_dir, "README.rst")) as f:
107-
# Remove reST raw directive from README as they're not allowed on PyPI
108-
# Those blocks start with a newline and continue until the next newline
109-
mode = None
110-
lines = []
111-
for line in f:
112-
if line.startswith(".. raw::"):
113-
mode = "ignore_nl"
114-
elif line == "\n":
115-
mode = "wait_nl" if mode == "ignore_nl" else None
116-
if mode is None:
117-
lines.append(line)
118-
119-
long_description = "".join(lines)
120-
121-
122-
packages = [
123-
package
124-
for package in find_packages(where=".", exclude=("test_elasticsearch*",))
125-
if package == package_name or package.startswith(package_name + ".")
126-
]
127-
128-
setup(
129-
name=package_name,
130-
description="Python client for Elasticsearch",
131-
license="Apache-2.0",
132-
url="https://github.com/elastic/elasticsearch-py",
133-
long_description=long_description,
134-
long_description_content_type="text/x-rst",
135-
version=package_version,
136-
author="Elastic Client Library Maintainers",
137-
author_email="client-libs@elastic.co",
138-
project_urls={
139-
"Documentation": "https://elasticsearch-py.readthedocs.io",
140-
"Source Code": "https://github.com/elastic/elasticsearch-py",
141-
"Issue Tracker": "https://github.com/elastic/elasticsearch-py/issues",
142-
},
143-
packages=packages,
144-
package_data={"elasticsearch": ["py.typed", "*.pyi"]},
145-
include_package_data=True,
146-
zip_safe=False,
147-
classifiers=[
148-
"Development Status :: 5 - Production/Stable",
149-
"License :: OSI Approved :: Apache Software License",
150-
"Intended Audience :: Developers",
151-
"Operating System :: OS Independent",
152-
"Programming Language :: Python",
153-
"Programming Language :: Python :: 3",
154-
"Programming Language :: Python :: 3.7",
155-
"Programming Language :: Python :: 3.8",
156-
"Programming Language :: Python :: 3.9",
157-
"Programming Language :: Python :: 3.10",
158-
"Programming Language :: Python :: 3.11",
159-
"Programming Language :: Python :: 3.12",
160-
"Programming Language :: Python :: Implementation :: CPython",
161-
"Programming Language :: Python :: Implementation :: PyPy",
162-
],
163-
python_requires=">=3.7",
164-
install_requires=["elastic-transport>=8.13,<9"],
165-
extras_require={
166-
"requests": ["requests>=2.4.0, <3.0.0"],
167-
"async": ["aiohttp>=3,<4"],
168-
"orjson": ["orjson>=3"],
169-
},
170-
)

test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import os
19-
from typing import Any, AsyncIterator, Dict, List, Optional
18+
from typing import Any, Dict, List
2019

2120
from elastic_transport import AsyncTransport
2221

23-
from elasticsearch import AsyncElasticsearch
2422
from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService
2523

2624

@@ -87,69 +85,3 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
8785
async def perform_request(self, *args, **kwargs): # type: ignore
8886
self.requests.append(kwargs)
8987
return await super().perform_request(*args, **kwargs)
90-
91-
92-
def create_es_client(
93-
es_params: Optional[Dict[str, str]] = None, es_kwargs: Dict = {}
94-
) -> AsyncElasticsearch:
95-
if es_params is None:
96-
es_params = read_env()
97-
if not es_kwargs:
98-
es_kwargs = {}
99-
100-
if "es_cloud_id" in es_params:
101-
return AsyncElasticsearch(
102-
cloud_id=es_params["es_cloud_id"],
103-
api_key=es_params["es_api_key"],
104-
**es_kwargs,
105-
)
106-
return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs)
107-
108-
109-
def create_requests_saving_client() -> AsyncElasticsearch:
110-
return create_es_client(es_kwargs={"transport_class": AsyncRequestSavingTransport})
111-
112-
113-
async def es_client_fixture() -> AsyncIterator[AsyncElasticsearch]:
114-
params = read_env()
115-
client = create_es_client(params)
116-
117-
yield client
118-
119-
# clear indices
120-
await clear_test_indices(client)
121-
122-
# clear all test pipelines
123-
try:
124-
response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding")
125-
126-
for pipeline_id, _ in response.items():
127-
try:
128-
await client.ingest.delete_pipeline(id=pipeline_id)
129-
print(f"Deleted pipeline: {pipeline_id}") # noqa: T201
130-
except Exception as e:
131-
print(f"Pipeline error: {e}") # noqa: T201
132-
133-
except Exception:
134-
pass
135-
finally:
136-
await client.close()
137-
138-
139-
async def clear_test_indices(client: AsyncElasticsearch) -> None:
140-
response = await client.indices.get(index="_all")
141-
index_names = response.keys()
142-
for index_name in index_names:
143-
if index_name.startswith("test_"):
144-
await client.indices.delete(index=index_name)
145-
await client.indices.refresh(index="_all")
146-
147-
148-
def read_env() -> Dict:
149-
url = os.environ.get("ES_URL", "http://localhost:9200")
150-
cloud_id = os.environ.get("ES_CLOUD_ID")
151-
api_key = os.environ.get("ES_API_KEY")
152-
153-
if cloud_id:
154-
return {"es_cloud_id": cloud_id, "es_api_key": api_key}
155-
return {"es_url": url}

test_elasticsearch/test_server/test_vectorstore/_async/conftest.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import os
19+
import uuid
20+
from typing import AsyncIterator, Dict
21+
22+
import pytest
23+
import pytest_asyncio
24+
25+
from elasticsearch import AsyncElasticsearch
26+
27+
from ._test_utils import AsyncRequestSavingTransport
28+
29+
30+
@pytest_asyncio.fixture
31+
async def es_client(elasticsearch_url: str) -> AsyncIterator[AsyncElasticsearch]:
32+
client = _create_es_client(elasticsearch_url)
33+
34+
yield client
35+
36+
# clear indices
37+
await _clear_test_indices(client)
38+
39+
# clear all test pipelines
40+
try:
41+
response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding")
42+
43+
for pipeline_id, _ in response.items():
44+
try:
45+
await client.ingest.delete_pipeline(id=pipeline_id)
46+
print(f"Deleted pipeline: {pipeline_id}") # noqa: T201
47+
except Exception as e:
48+
print(f"Pipeline error: {e}") # noqa: T201
49+
50+
except Exception:
51+
pass
52+
finally:
53+
await client.close()
54+
55+
56+
@pytest_asyncio.fixture
57+
async def requests_saving_client(
58+
elasticsearch_url: str,
59+
) -> AsyncIterator[AsyncElasticsearch]:
60+
client = _create_es_client(
61+
elasticsearch_url, es_kwargs={"transport_class": AsyncRequestSavingTransport}
62+
)
63+
64+
try:
65+
yield client
66+
finally:
67+
await client.close()
68+
69+
70+
@pytest.fixture(scope="function")
71+
def index_name() -> str:
72+
return f"test_{uuid.uuid4().hex}"
73+
74+
75+
async def _clear_test_indices(client: AsyncElasticsearch) -> None:
76+
response = await client.indices.get(index="_all")
77+
index_names = response.keys()
78+
for index_name in index_names:
79+
if index_name.startswith("test_"):
80+
await client.indices.delete(index=index_name)
81+
await client.indices.refresh(index="_all")
82+
83+
84+
def _create_es_client(
85+
elasticsearch_url: str, es_kwargs: Dict = {}
86+
) -> AsyncElasticsearch:
87+
if not elasticsearch_url:
88+
elasticsearch_url = os.environ.get("ES_URL", "http://localhost:9200")
89+
cloud_id = os.environ.get("ES_CLOUD_ID")
90+
api_key = os.environ.get("ES_API_KEY")
91+
92+
if cloud_id:
93+
es_params = {"es_cloud_id": cloud_id, "es_api_key": api_key}
94+
else:
95+
es_params = {"es_url": elasticsearch_url}
96+
97+
if "es_cloud_id" in es_params:
98+
return AsyncElasticsearch(
99+
cloud_id=es_params["es_cloud_id"],
100+
api_key=es_params["es_api_key"],
101+
**es_kwargs,
102+
)
103+
return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs)

test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,21 @@
1616
# under the License.
1717

1818
import os
19-
from typing import AsyncIterator
2019

2120
import pytest
22-
import pytest_asyncio
2321

2422
from elasticsearch import AsyncElasticsearch
2523
from elasticsearch.vectorstore._async._utils import model_is_deployed
2624
from elasticsearch.vectorstore._async.embedding_service import (
2725
AsyncElasticsearchEmbeddings,
2826
)
2927

30-
from ._test_utils import es_client_fixture
31-
3228
# deployed with
3329
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
3430
MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3")
3531
NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384"))
3632

3733

38-
@pytest_asyncio.fixture
39-
async def es_client() -> AsyncIterator[AsyncElasticsearch]:
40-
async for x in es_client_fixture():
41-
yield x
42-
43-
4434
@pytest.mark.asyncio
4535
async def test_elasticsearch_embedding_documents(es_client: AsyncElasticsearch) -> None:
4636
"""Test Elasticsearch embedding documents."""

0 commit comments

Comments
 (0)