Skip to content

Commit 79c4ba4

Browse files
Sean UnderwoodYury Malkov
Sean Underwood
authored and
Yury Malkov
committed
Raising an error when trying to get data for a non-existent label
1 parent 79c74b4 commit 79c4ba4

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

hnswlib/hnswalg.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,13 @@ namespace hnswlib {
633633
template<typename data_t>
634634
std::vector<data_t> getDataByLabel(labeltype label)
635635
{
636-
tableint label_c = label_lookup_[label];
636+
tableint label_c;
637+
auto search = label_lookup_.find(label);
638+
if (search == label_lookup_.end()) {
639+
throw std::runtime_error("Label not found");
640+
}
641+
label_c = search->second;
642+
637643
char* data_ptrv = getDataByInternalId(label_c);
638644
size_t dim = *((size_t *) dist_func_param_);
639645
std::vector<data_t> data;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import unittest
2+
3+
4+
class RandomSelfTestCase(unittest.TestCase):
5+
def testGettingItems(self):
6+
import hnswlib
7+
import numpy as np
8+
9+
dim = 16
10+
num_elements = 10000
11+
12+
# Generating sample data
13+
data = np.float32(np.random.random((num_elements, dim)))
14+
labels = np.arange(0, num_elements)
15+
16+
# Declaring index
17+
p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip
18+
19+
# Initing index
20+
# max_elements - the maximum number of elements, should be known beforehand
21+
# (probably will be made optional in the future)
22+
#
23+
# ef_construction - controls index search speed/build speed tradeoff
24+
# M - is tightly connected with internal dimensionality of the data
25+
# stronlgy affects the memory consumption
26+
27+
p.init_index(max_elements=num_elements, ef_construction=100, M=16)
28+
29+
# Controlling the recall by setting ef:
30+
# higher ef leads to better accuracy, but slower search
31+
p.set_ef(300)
32+
33+
p.set_num_threads(4) # by default using all available cores
34+
35+
# Before adding anything, getting any labels should fail
36+
self.assertRaises(Exception, lambda: p.get_items(labels))
37+
38+
print("Adding all elements (%d)" % (len(data)))
39+
p.add_items(data, labels)
40+
41+
# After adding them, all labels should be retrievable
42+
returned_items = p.get_items(labels)
43+
self.assertSequenceEqual(data.tolist(), returned_items)
44+
45+
if __name__ == "__main__":
46+
unittest.main()

0 commit comments

Comments
 (0)