Skip to content

Commit 6aac477

Browse files
authored
Add multithread search to BF index (#425)
* Add multithread search for BF index
1 parent dccd4f9 commit 6aac477

File tree

2 files changed

+87
-7
lines changed

2 files changed

+87
-7
lines changed

python_bindings/bindings.cpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,7 @@ class BFIndex {
719719
int dim;
720720
bool index_inited;
721721
bool normalize;
722+
int num_threads_default;
722723

723724
hnswlib::labeltype cur_l;
724725
hnswlib::BruteforceSearch<dist_t>* alg;
@@ -739,6 +740,8 @@ class BFIndex {
739740
}
740741
alg = NULL;
741742
index_inited = false;
743+
744+
num_threads_default = std::thread::hardware_concurrency();
742745
}
743746

744747

@@ -749,6 +752,21 @@ class BFIndex {
749752
}
750753

751754

755+
size_t getMaxElements() const {
756+
return alg->maxelements_;
757+
}
758+
759+
760+
size_t getCurrentCount() const {
761+
return alg->cur_element_count;
762+
}
763+
764+
765+
void set_num_threads(int num_threads) {
766+
this->num_threads_default = num_threads;
767+
}
768+
769+
752770
void init_new_index(const size_t maxElements) {
753771
if (alg) {
754772
throw std::runtime_error("The index is already initiated.");
@@ -820,15 +838,19 @@ class BFIndex {
820838
py::object knnQuery_return_numpy(
821839
py::object input,
822840
size_t k = 1,
841+
int num_threads = -1,
823842
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
824843
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
825844
auto buffer = items.request();
826845
hnswlib::labeltype *data_numpy_l;
827846
dist_t *data_numpy_d;
828847
size_t rows, features;
848+
849+
if (num_threads <= 0)
850+
num_threads = num_threads_default;
851+
829852
{
830853
py::gil_scoped_release l;
831-
832854
get_input_array_shapes(buffer, &rows, &features);
833855

834856
data_numpy_l = new hnswlib::labeltype[rows * k];
@@ -837,16 +859,16 @@ class BFIndex {
837859
CustomFilterFunctor idFilter(filter);
838860
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;
839861

840-
for (size_t row = 0; row < rows; row++) {
862+
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
841863
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
842-
(void *) items.data(row), k, p_idFilter);
864+
(void*)items.data(row), k, p_idFilter);
843865
for (int i = k - 1; i >= 0; i--) {
844-
auto &result_tuple = result.top();
866+
auto& result_tuple = result.top();
845867
data_numpy_d[row * k + i] = result_tuple.first;
846868
data_numpy_l[row * k + i] = result_tuple.second;
847869
result.pop();
848870
}
849-
}
871+
});
850872
}
851873

852874
py::capsule free_when_done_l(data_numpy_l, [](void *f) {
@@ -957,13 +979,22 @@ PYBIND11_PLUGIN(hnswlib) {
957979
py::class_<BFIndex<float>>(m, "BFIndex")
958980
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
959981
.def("init_index", &BFIndex<float>::init_new_index, py::arg("max_elements"))
960-
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none())
982+
.def("knn_query",
983+
&BFIndex<float>::knnQuery_return_numpy,
984+
py::arg("data"),
985+
py::arg("k") = 1,
986+
py::arg("num_threads") = -1,
987+
py::arg("filter") = py::none())
961988
.def("add_items", &BFIndex<float>::addItems, py::arg("data"), py::arg("ids") = py::none())
962989
.def("delete_vector", &BFIndex<float>::deleteVector, py::arg("label"))
990+
.def("set_num_threads", &BFIndex<float>::set_num_threads, py::arg("num_threads"))
963991
.def("save_index", &BFIndex<float>::saveIndex, py::arg("path_to_index"))
964992
.def("load_index", &BFIndex<float>::loadIndex, py::arg("path_to_index"), py::arg("max_elements") = 0)
965993
.def("__repr__", [](const BFIndex<float> &a) {
966994
return "<hnswlib.BFIndex(space='" + a.space_name + "', dim="+std::to_string(a.dim)+")>";
967-
});
995+
})
996+
.def("get_max_elements", &BFIndex<float>::getMaxElements)
997+
.def("get_current_count", &BFIndex<float>::getCurrentCount)
998+
.def_readwrite("num_threads", &BFIndex<float>::num_threads_default);
968999
return m.ptr();
9691000
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
import hnswlib
6+
7+
8+
class RandomSelfTestCase(unittest.TestCase):
    """Checks BFIndex k-NN distances against a NumPy brute-force reference."""

    def testBFIndex(self):
        dim = 16
        num_elements = 10000
        num_queries = 1000
        k = 20

        # Generating sample data
        data = np.float32(np.random.random((num_elements, dim)))

        # Declaring index
        bf_index = hnswlib.BFIndex(space='l2', dim=dim)  # possible options are l2, cosine or ip
        bf_index.init_index(max_elements=num_elements)

        num_threads = 8
        bf_index.set_num_threads(num_threads)  # by default using all available cores

        print(f"Adding all elements {num_elements}")
        bf_index.add_items(data)

        # The accessors must reflect what was configured and inserted above.
        self.assertEqual(bf_index.num_threads, num_threads)
        self.assertEqual(bf_index.get_max_elements(), num_elements)
        self.assertEqual(bf_index.get_current_count(), num_elements)

        queries = np.float32(np.random.random((num_queries, dim)))
        print("Searching nearest neighbours")
        _, distances = bf_index.knn_query(queries, k=k)

        print("Checking results")
        for i in range(num_queries):
            query = queries[i]
            sq_dists = (data - query) ** 2
            dists = np.sum(sq_dists, axis=1)
            labels_gt = np.argsort(dists)[:k]
            dists_gt = dists[labels_gt]
            dists_bf = distances[i]
            # we can compare labels but because of numeric errors in distance calculation in C++ and numpy
            # sometimes we get different order of labels, therefore we compare distances
            max_diff_with_gt = np.max(np.abs(dists_gt - dists_bf))

            # assertLess reports both operands on failure, unlike assertTrue(a < b).
            self.assertLess(
                max_diff_with_gt,
                1e-5,
                msg=f"distance mismatch for query {i}",
            )

0 commit comments

Comments
 (0)