Skip to content

Commit 5ba4c4c

Browse files
committed
add getting a list of all labels
1 parent dbb4f01 commit 5ba4c4c

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

python_bindings/bindings.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class Index {
216216
}
217217
}
218218

219-
std::vector<std::vector<data_t>> GetDataReturnList(py::object ids_ = py::none()) {
219+
std::vector<std::vector<data_t>> getDataReturnList(py::object ids_ = py::none()) {
220220
std::vector<size_t> ids;
221221
if (!ids_.is_none()) {
222222
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
@@ -235,6 +235,16 @@ class Index {
235235
return data;
236236
}
237237

238+
std::vector<unsigned int> getIdsList() {
239+
240+
std::vector<unsigned int> ids;
241+
242+
for(auto kv : appr_alg->label_lookup_) {
243+
ids.push_back(kv.first);
244+
}
245+
return ids;
246+
}
247+
238248
py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {
239249

240250
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
@@ -360,7 +370,8 @@ PYBIND11_PLUGIN(hnswlib) {
360370
py::arg("ef_construction")=200, py::arg("random_seed")=100)
361371
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1, py::arg("num_threads")=-1)
362372
.def("add_items", &Index<float>::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads")=-1)
363-
.def("get_items", &Index<float, float>::GetDataReturnList, py::arg("ids") = py::none())
373+
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
374+
.def("get_ids_list", &Index<float>::getIdsList)
364375
.def("set_ef", &Index<float>::set_ef, py::arg("ef"))
365376
.def("set_num_threads", &Index<float>::set_num_threads, py::arg("num_threads"))
366377
.def("save_index", &Index<float>::saveIndex, py::arg("path_to_index"))

python_bindings/tests/bindings_test_labels.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def testRandomSelf(self):
8181
diff_with_gt_labels=np.max(np.abs(data-items))
8282
self.assertAlmostEqual(diff_with_gt_labels,0,1e-4)
8383

84+
# Checking that all labels are returned correcly:
85+
sorted_labels=sorted(p.get_ids_list())
86+
self.assertEqual(np.sum(~np.asarray(sorted_labels)==np.asarray(range(num_elements))),0)
8487

8588

8689

0 commit comments

Comments
 (0)