Skip to content

Commit cb7b398

Browse files
committed
New methods loadIndexFromStream and saveIndexToStream expose the deserialization and serialization logic of the HierarchicalNSW class via std::istream and std::ostream.
1 parent 21b54fe commit cb7b398

File tree

1 file changed

+33
-30
lines changed

1 file changed

+33
-30
lines changed

hnswlib/hnswalg.h

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ namespace hnswlib {
2626
loadIndex(location, s, max_elements);
2727
}
2828

29+
/// Constructs an index by deserializing it from `input`.
/// `input`, `s` and `max_elements` are forwarded unchanged to loadIndexFromStream.
/// NOTE(review): `nmslib` is not read in this body; presumably it is kept for
/// signature parity with the file-based constructor — confirm.
HierarchicalNSW(SpaceInterface<dist_t> *s, std::istream & input, bool nmslib = false, size_t max_elements=0) {
30+
loadIndexFromStream(input, s, max_elements);
31+
}
32+
2933
HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
3034
link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
3135
max_elements_ = max_elements;
@@ -57,8 +61,6 @@ namespace hnswlib {
5761

5862
visited_list_pool_ = new VisitedListPool(1, max_elements);
5963

60-
61-
6264
//initializations for special treatment of the first node
6365
enterpoint_node_ = -1;
6466
maxlevel_ = -1;
@@ -102,6 +104,8 @@ namespace hnswlib {
102104
double mult_, revSize_;
103105
int maxlevel_;
104106

107+
std::mutex global;
108+
size_t ef_;
105109

106110
VisitedListPool *visited_list_pool_;
107111
std::mutex cur_element_count_guard_;
@@ -511,8 +515,6 @@ namespace hnswlib {
511515
return next_closest_entry_point;
512516
}
513517

514-
std::mutex global;
515-
size_t ef_;
516518

517519
void setEf(size_t ef) {
518520
ef_ = ef;
@@ -598,10 +600,7 @@ namespace hnswlib {
598600
max_elements_=new_max_elements;
599601

600602
}
601-
602-
void saveIndex(const std::string &location) {
603-
std::ofstream output(location, std::ios::binary);
604-
std::streampos position;
603+
void saveIndexToStream(std::ostream &output) {
605604

606605
writeBinaryPOD(output, offsetLevel0_);
607606
writeBinaryPOD(output, max_elements_);
@@ -626,17 +625,17 @@ namespace hnswlib {
626625
if (linkListSize)
627626
output.write(linkLists_[i], linkListSize);
628627
}
629-
output.close();
630-
}
631-
632-
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
633628

629+
}
634630

635-
std::ifstream input(location, std::ios::binary);
636-
637-
if (!input.is_open())
638-
throw std::runtime_error("Cannot open file");
631+
/// Saves the index to a binary file at `location`.
///
/// Opens the file for binary output and delegates the serialization itself to
/// saveIndexToStream.
///
/// @param location path of the file to write.
/// @throws std::runtime_error if the file cannot be opened for writing.
void saveIndex(const std::string &location) {
    std::ofstream output(location, std::ios::binary);
    // Fail loudly instead of silently writing nothing when the path is not
    // writable; loadIndex reports the same condition with the same message.
    if (!output.is_open())
        throw std::runtime_error("Cannot open file");

    saveIndexToStream(output);
    output.close();
}
639637

638+
void loadIndexFromStream(std::istream & input, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
640639

641640
// get file size:
642641
input.seekg(0,input.end);
@@ -663,14 +662,12 @@ namespace hnswlib {
663662
readBinaryPOD(input, mult_);
664663
readBinaryPOD(input, ef_construction_);
665664

666-
667665
data_size_ = s->get_data_size();
668666
fstdistfunc_ = s->get_dist_func();
669667
dist_func_param_ = s->get_dist_func_param();
670668

671669
auto pos=input.tellg();
672670

673-
674671
/// Optional - check if index is ok:
675672

676673
input.seekg(cur_element_count * size_data_per_element_,input.cur);
@@ -696,15 +693,11 @@ namespace hnswlib {
696693

697694
input.seekg(pos,input.beg);
698695

699-
700696
data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
701697
if (data_level0_memory_ == nullptr)
702698
throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
703699
input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
704700

705-
706-
707-
708701
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
709702

710703

@@ -715,7 +708,6 @@ namespace hnswlib {
715708

716709
visited_list_pool_ = new VisitedListPool(1, max_elements);
717710

718-
719711
linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
720712
if (linkLists_ == nullptr)
721713
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
@@ -746,11 +738,22 @@ namespace hnswlib {
746738
has_deletions_=true;
747739
}
748740

749-
input.close();
750741

751742
return;
752743
}
753744

745+
746+
747+
/// Loads an index from the binary file at `location`.
///
/// @param location       path of the file to read.
/// @param s              passed through to loadIndexFromStream.
/// @param max_elements_i passed through to loadIndexFromStream (defaults to 0).
/// @throws std::runtime_error if the file cannot be opened.
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
    std::ifstream input;
    input.open(location, std::ios::binary);
    if (!input.is_open()) {
        throw std::runtime_error("Cannot open file");
    }

    // This wrapper only owns the file handle; the parsing is done by
    // loadIndexFromStream.
    loadIndexFromStream(input, s, max_elements_i);
    input.close();
}
756+
754757
template<typename data_t>
755758
std::vector<data_t> getDataByLabel(labeltype label)
756759
{
@@ -874,7 +877,7 @@ namespace hnswlib {
874877
for (auto&& cand : sCand) {
875878
if (cand == neigh)
876879
continue;
877-
880+
878881
dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
879882
if (candidates.size() < elementsToKeep) {
880883
candidates.emplace(distance, cand);
@@ -1137,7 +1140,7 @@ namespace hnswlib {
11371140
}
11381141

11391142
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1140-
if (has_deletions_) {
1143+
if (has_deletions_) {
11411144
top_candidates=searchBaseLayerST<true,true>(
11421145
currObj, query_data, std::max(ef_, k));
11431146
}
@@ -1186,27 +1189,27 @@ namespace hnswlib {
11861189
std::unordered_set<tableint> s;
11871190
for (int j=0; j<size; j++){
11881191
assert(data[j] > 0);
1189-
assert(data[j] < cur_element_count);
1192+
assert(data[j] < cur_element_count);
11901193
assert (data[j] != i);
11911194
inbound_connections_num[data[j]]++;
11921195
s.insert(data[j]);
11931196
connections_checked++;
1194-
1197+
11951198
}
11961199
assert(s.size() == size);
11971200
}
11981201
}
11991202
if(cur_element_count > 1){
12001203
int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
1201-
for(int i=0; i < cur_element_count; i++){
1204+
for(int i=0; i < cur_element_count; i++){
12021205
assert(inbound_connections_num[i] > 0);
12031206
min1=std::min(inbound_connections_num[i],min1);
12041207
max1=std::max(inbound_connections_num[i],max1);
12051208
}
12061209
std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
12071210
}
12081211
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1209-
1212+
12101213
}
12111214

12121215
};

0 commit comments

Comments
 (0)