
Commit 4cda418

Vector search example
1 parent c650b62 commit 4cda418

File tree

2 files changed: +339 -0 lines changed

examples/async/vectors.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
# Vector database example

Requirements:

$ pip install nltk sentence_transformers tqdm elasticsearch_dsl

To run the example:

$ python vectors.py "text to search"

The index will be created automatically if it does not exist. Add `--create` to
regenerate it.

The example dataset includes a selection of workplace documentation. The
following are good example queries to try out:

$ python vectors.py "work from home"
$ python vectors.py "vacation time"
$ python vectors.py "bring a bird to work"
"""

import argparse
import asyncio
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
    AsyncDocument,
    Date,
    DenseVector,
    InnerDoc,
    Keyword,
    Nested,
    Text,
    async_connections,
)

DATASET_URL = "https://github.com/raw/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize sentence tokenizer
nltk.download("punkt", quiet=True)


class Passage(InnerDoc):
    content = Text()
    embedding = DenseVector()


class WorkplaceDoc(AsyncDocument):
    class Index:
        name = "workplace_documents"

    name = Text()
    summary = Text()
    content = Text()
    created = Date()
    updated = Date()
    url = Keyword()
    category = Keyword()
    passages = Nested(Passage)

    _model = None

    @classmethod
    def get_embedding_model(cls):
        if cls._model is None:
            cls._model = SentenceTransformer(MODEL_NAME)
        return cls._model

    def clean(self):
        # split the content into sentences
        passages = nltk.sent_tokenize(self.content)

        # generate an embedding for each passage and save it as a nested document
        model = self.get_embedding_model()
        for passage in passages:
            self.passages.append(
                Passage(content=passage, embedding=list(model.encode(passage)))
            )


async def create():
    # create the index
    await WorkplaceDoc._index.delete(ignore_unavailable=True)
    await WorkplaceDoc.init()

    # download the data
    dataset = json.loads(urlopen(DATASET_URL).read())

    # import the dataset
    for data in tqdm(dataset, desc="Indexing documents..."):
        doc = WorkplaceDoc(
            name=data["name"],
            summary=data["summary"],
            content=data["content"],
            created=data.get("created_on"),
            updated=data.get("updated_at"),
            url=data["url"],
            category=data["category"],
        )
        await doc.save()


async def search(query):
    model = WorkplaceDoc.get_embedding_model()
    search = WorkplaceDoc.search().knn(
        field="passages.embedding",
        k=5,
        num_candidates=50,
        query_vector=list(model.encode(query)),
        inner_hits={"size": 3},
    )
    async for hit in search:
print(f"Document: {hit.name} (Category: {hit.category}")
140+
        for passage in hit.meta.inner_hits.passages:
            print(f" - [Score: {passage.meta.score}] {passage.content!r}")
        print("")


def parse_args():
    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
    parser.add_argument(
        "--create", action="store_true", help="Create and populate a new index"
    )
    parser.add_argument("query", action="store", help="The search query")
    return parser.parse_args()


async def main():
    args = parse_args()

    # initiate the default connection to elasticsearch
    async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

    if args.create or not await WorkplaceDoc._index.exists():
        await create()

    await search(args.query)

    # close the connection
    await async_connections.get_connection().close()


if __name__ == "__main__":
    asyncio.run(main())
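
For reference, the `.knn(...)` call in `search()` builds a kNN search request against the nested embedding field. Below is a rough sketch of the request body it is expected to produce; this is an assumption for illustration, not captured wire output, and the exact shape can vary with the elasticsearch_dsl and Elasticsearch versions. The 384-dimension figure comes from the all-MiniLM-L6-v2 model.

# Sketch (assumption, not verified output) of the kNN search body that
# WorkplaceDoc.search().knn(...) above would send to an Elasticsearch 8.x cluster:
knn_body = {
    "knn": {
        "field": "passages.embedding",  # the nested dense_vector field
        "k": 5,  # number of nearest passages to return
        "num_candidates": 50,  # candidates examined per shard
        "query_vector": [0.0] * 384,  # placeholder; all-MiniLM-L6-v2 emits 384 floats
        "inner_hits": {"size": 3},  # top 3 matching passages per document
    }
}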

examples/vectors.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
# Vector database example

Requirements:

$ pip install nltk sentence_transformers tqdm elasticsearch_dsl

To run the example:

$ python vectors.py "text to search"

The index will be created automatically if it does not exist. Add `--create` to
regenerate it.

The example dataset includes a selection of workplace documentation. The
following are good example queries to try out:

$ python vectors.py "work from home"
$ python vectors.py "vacation time"
$ python vectors.py "bring a bird to work"
"""

import argparse
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
    Date,
    DenseVector,
    Document,
    InnerDoc,
    Keyword,
    Nested,
    Text,
    connections,
)

DATASET_URL = "https://github.com/raw/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize sentence tokenizer
nltk.download("punkt", quiet=True)


class Passage(InnerDoc):
    content = Text()
    embedding = DenseVector()


class WorkplaceDoc(Document):
    class Index:
        name = "workplace_documents"

    name = Text()
    summary = Text()
    content = Text()
    created = Date()
    updated = Date()
    url = Keyword()
    category = Keyword()
    passages = Nested(Passage)

    _model = None

    @classmethod
    def get_embedding_model(cls):
        if cls._model is None:
            cls._model = SentenceTransformer(MODEL_NAME)
        return cls._model

    def clean(self):
        # split the content into sentences
        passages = nltk.sent_tokenize(self.content)

        # generate an embedding for each passage and save it as a nested document
        model = self.get_embedding_model()
        for passage in passages:
            self.passages.append(
                Passage(content=passage, embedding=list(model.encode(passage)))
            )


def create():
    # create the index
    WorkplaceDoc._index.delete(ignore_unavailable=True)
    WorkplaceDoc.init()

    # download the data
    dataset = json.loads(urlopen(DATASET_URL).read())

    # import the dataset
    for data in tqdm(dataset, desc="Indexing documents..."):
        doc = WorkplaceDoc(
            name=data["name"],
            summary=data["summary"],
            content=data["content"],
            created=data.get("created_on"),
            updated=data.get("updated_at"),
            url=data["url"],
            category=data["category"],
        )
        doc.save()


def search(query):
    model = WorkplaceDoc.get_embedding_model()
    search = WorkplaceDoc.search().knn(
        field="passages.embedding",
        k=5,
        num_candidates=50,
        query_vector=list(model.encode(query)),
        inner_hits={"size": 3},
    )
    for hit in search:
print(f"Document: {hit.name} (Category: {hit.category}")
139+
        for passage in hit.meta.inner_hits.passages:
            print(f" - [Score: {passage.meta.score}] {passage.content!r}")
        print("")


def parse_args():
    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
    parser.add_argument(
        "--create", action="store_true", help="Create and populate a new index"
    )
    parser.add_argument("query", action="store", help="The search query")
    return parser.parse_args()


def main():
    args = parse_args()

    # initiate the default connection to elasticsearch
    connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

    if args.create or not WorkplaceDoc._index.exists():
        create()

    search(args.query)

    # close the connection
    connections.get_connection().close()


if __name__ == "__main__":
    main()
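
Both variants rely on elasticsearch_dsl invoking the document's `clean()` hook during `save()` validation, which is where the per-sentence embeddings are generated. A minimal sketch of exercising that hook in isolation is below; it assumes the packages from the docstring are installed and that `vectors.py` is importable from the working directory, the document text is hypothetical, and no Elasticsearch connection is needed for this step:

# Hypothetical standalone check of the clean() hook (no cluster required):
from vectors import WorkplaceDoc

doc = WorkplaceDoc(content="Dogs are allowed. Birds are not.")
doc.clean()  # save() would run this as part of validation before indexing
print(len(doc.passages))  # expected: 2 sentences -> 2 nested passages
print(len(doc.passages[0].embedding))  # expected: 384 (all-MiniLM-L6-v2)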
