Commit c1a21b6

Vector search example
1 parent c650b62 commit c1a21b6

File tree

examples/async/vectors.py
examples/vectors.py

2 files changed: +303 −0 lines changed

examples/async/vectors.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Vector database index creation example
"""

import argparse
import asyncio
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
    AsyncDocument,
    Date,
    DenseVector,
    InnerDoc,
    Keyword,
    Nested,
    Text,
    async_connections,
)

DATASET_URL = "https://github.com/raw/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize the sentence tokenizer
nltk.download("punkt", quiet=True)


class Passage(InnerDoc):
    content = Text()
    embedding = DenseVector()


class WorkplaceDoc(AsyncDocument):
    class Index:
        name = "workplace_documents"

    name = Text()
    summary = Text()
    content = Text()
    created = Date()
    updated = Date()
    url = Keyword()
    category = Keyword()
    passages = Nested(Passage)

    _model = None

    @classmethod
    def get_embedding_model(cls):
        if cls._model is None:
            cls._model = SentenceTransformer(MODEL_NAME)
        return cls._model

    def clean(self):
        # split the content into sentences
        passages = nltk.sent_tokenize(self.content)

        # generate an embedding for each passage and save it as a nested document
        model = self.get_embedding_model()
        for passage in passages:
            self.passages.append(
                Passage(content=passage, embedding=list(model.encode(passage)))
            )


async def create():
    # create the index
    await WorkplaceDoc._index.delete(ignore_unavailable=True)
    await WorkplaceDoc.init()

    # download the data
    dataset = json.loads(urlopen(DATASET_URL).read())

    # import the dataset
    for data in tqdm(dataset, desc="Indexing documents..."):
        doc = WorkplaceDoc(
            name=data["name"],
            summary=data["summary"],
            content=data["content"],
            created=data.get("created_on"),
            updated=data.get("updated_at"),
            url=data["url"],
            category=data["category"],
        )
        await doc.save()


async def search(query):
    model = WorkplaceDoc.get_embedding_model()
    search = WorkplaceDoc.search().knn(
        field="passages.embedding",
        k=5,
        num_candidates=50,
        query_vector=list(model.encode(query)),
        inner_hits={"size": 3},
    )
    async for hit in search:
        print(f"Document: {hit.name}")
        for passage in hit.meta.inner_hits.passages:
            print(f"  - [Score: {passage.meta.score}] {passage.content!r}")
        print("")


def parse_args():
    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
    parser.add_argument(
        "--create", action="store_true", help="Create and populate a new index"
    )
    parser.add_argument("query", action="store", help="The search query")
    return parser.parse_args()


async def main():
    args = parse_args()

    # initiate the default connection to elasticsearch
    async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

    if args.create or not await WorkplaceDoc._index.exists():
        await create()

    await search(args.query)

    # close the connection
    await async_connections.get_connection().close()


if __name__ == "__main__":
    asyncio.run(main())
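
The `.knn()` call above maps onto Elasticsearch's kNN search API. As a rough sketch, the request body it is expected to serialize to looks like the following (an illustration, not part of this commit; the placeholder vector stands in for the 384-dimensional all-MiniLM-L6-v2 embedding of the query):

# Approximate body of the kNN search built above (assumption for illustration).
knn_body = {
    "knn": {
        "field": "passages.embedding",
        "k": 5,
        "num_candidates": 50,
        "query_vector": [0.0] * 384,  # placeholder; really list(model.encode(query))
        "inner_hits": {"size": 3},  # surface the best matching passages per document
    }
}

Here `num_candidates` controls the per-shard pool of approximate neighbors from which the top `k` hits are selected, so larger values trade latency for recall.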

examples/vectors.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Vector database index creation example
"""

import argparse
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
    Date,
    DenseVector,
    Document,
    InnerDoc,
    Keyword,
    Nested,
    Text,
    connections,
)

DATASET_URL = "https://github.com/raw/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize the sentence tokenizer
nltk.download("punkt", quiet=True)


class Passage(InnerDoc):
    content = Text()
    embedding = DenseVector()


class WorkplaceDoc(Document):
    class Index:
        name = "workplace_documents"

    name = Text()
    summary = Text()
    content = Text()
    created = Date()
    updated = Date()
    url = Keyword()
    category = Keyword()
    passages = Nested(Passage)

    _model = None

    @classmethod
    def get_embedding_model(cls):
        if cls._model is None:
            cls._model = SentenceTransformer(MODEL_NAME)
        return cls._model

    def clean(self):
        # split the content into sentences
        passages = nltk.sent_tokenize(self.content)

        # generate an embedding for each passage and save it as a nested document
        model = self.get_embedding_model()
        for passage in passages:
            self.passages.append(
                Passage(content=passage, embedding=list(model.encode(passage)))
            )


def create():
    # create the index
    WorkplaceDoc._index.delete(ignore_unavailable=True)
    WorkplaceDoc.init()

    # download the data
    dataset = json.loads(urlopen(DATASET_URL).read())

    # import the dataset
    for data in tqdm(dataset, desc="Indexing documents..."):
        doc = WorkplaceDoc(
            name=data["name"],
            summary=data["summary"],
            content=data["content"],
            created=data.get("created_on"),
            updated=data.get("updated_at"),
            url=data["url"],
            category=data["category"],
        )
        doc.save()


def search(query):
    model = WorkplaceDoc.get_embedding_model()
    search = WorkplaceDoc.search().knn(
        field="passages.embedding",
        k=5,
        num_candidates=50,
        query_vector=list(model.encode(query)),
        inner_hits={"size": 3},
    )
    for hit in search:
        print(f"Document: {hit.name}")
        for passage in hit.meta.inner_hits.passages:
            print(f"  - [Score: {passage.meta.score}] {passage.content!r}")
        print("")


def parse_args():
    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
    parser.add_argument(
        "--create", action="store_true", help="Create and populate a new index"
    )
    parser.add_argument("query", action="store", help="The search query")
    return parser.parse_args()


def main():
    args = parse_args()

    # initiate the default connection to elasticsearch
    connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

    if args.create or not WorkplaceDoc._index.exists():
        create()

    search(args.query)

    # close the connection
    connections.get_connection().close()


if __name__ == "__main__":
    main()
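
Beyond the command-line entry point, the document class can be reused from other code. A minimal sketch, assuming the index was already populated by a `--create` run and that the file is importable as a `vectors` module (the import path and sample query are illustrative, not part of this commit):

# Hypothetical standalone usage of the WorkplaceDoc model defined above.
import os

from elasticsearch_dsl import connections

from vectors import WorkplaceDoc  # assumed import path for this example file

connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

model = WorkplaceDoc.get_embedding_model()
results = WorkplaceDoc.search().knn(
    field="passages.embedding",
    k=5,
    num_candidates=50,
    query_vector=list(model.encode("how do I request time off?")),
    inner_hits={"size": 3},
)
for hit in results:
    print(hit.name, hit.category)  # top documents by nearest passage embedding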
