@@ -26,6 +26,10 @@ namespace hnswlib {
26
26
loadIndex (location, s, max_elements);
27
27
}
28
28
29
+ HierarchicalNSW (SpaceInterface<dist_t > *s, std::istream & input, bool nmslib = false , size_t max_elements=0 ) {
30
+ loadIndexFromStream (input, s, max_elements);
31
+ }
32
+
29
33
HierarchicalNSW (SpaceInterface<dist_t > *s, size_t max_elements, size_t M = 16 , size_t ef_construction = 200 , size_t random_seed = 100 ) :
30
34
link_list_locks_ (max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
31
35
max_elements_ = max_elements;
@@ -57,8 +61,6 @@ namespace hnswlib {
57
61
58
62
visited_list_pool_ = new VisitedListPool (1 , max_elements);
59
63
60
-
61
-
62
64
// initializations for special treatment of the first node
63
65
enterpoint_node_ = -1 ;
64
66
maxlevel_ = -1 ;
@@ -102,6 +104,8 @@ namespace hnswlib {
102
104
double mult_, revSize_;
103
105
int maxlevel_;
104
106
107
+ std::mutex global;
108
+ size_t ef_;
105
109
106
110
VisitedListPool *visited_list_pool_;
107
111
std::mutex cur_element_count_guard_;
@@ -511,8 +515,6 @@ namespace hnswlib {
511
515
return next_closest_entry_point;
512
516
}
513
517
514
- std::mutex global;
515
- size_t ef_;
516
518
517
519
void setEf (size_t ef) {
518
520
ef_ = ef;
@@ -598,10 +600,7 @@ namespace hnswlib {
598
600
max_elements_=new_max_elements;
599
601
600
602
}
601
-
602
- void saveIndex (const std::string &location) {
603
- std::ofstream output (location, std::ios::binary);
604
- std::streampos position;
603
+ void saveIndexToStream (std::ostream &output) {
605
604
606
605
writeBinaryPOD (output, offsetLevel0_);
607
606
writeBinaryPOD (output, max_elements_);
@@ -626,17 +625,17 @@ namespace hnswlib {
626
625
if (linkListSize)
627
626
output.write (linkLists_[i], linkListSize);
628
627
}
629
- output.close ();
630
- }
631
-
632
- void loadIndex (const std::string &location, SpaceInterface<dist_t > *s, size_t max_elements_i=0 ) {
633
628
629
+ }
634
630
635
- std::ifstream input (location, std::ios::binary);
636
-
637
- if (!input.is_open ())
638
- throw std::runtime_error (" Cannot open file" );
631
+ void saveIndex (const std::string &location) {
632
+ std::ofstream output (location, std::ios::binary);
633
+ std::streampos position;
634
+ saveIndexToStream (output);
635
+ output.close ();
636
+ }
639
637
638
+ void loadIndexFromStream (std::istream & input, SpaceInterface<dist_t > *s, size_t max_elements_i=0 ) {
640
639
641
640
// get file size:
642
641
input.seekg (0 ,input.end );
@@ -663,14 +662,12 @@ namespace hnswlib {
663
662
readBinaryPOD (input, mult_);
664
663
readBinaryPOD (input, ef_construction_);
665
664
666
-
667
665
data_size_ = s->get_data_size ();
668
666
fstdistfunc_ = s->get_dist_func ();
669
667
dist_func_param_ = s->get_dist_func_param ();
670
668
671
669
auto pos=input.tellg ();
672
670
673
-
674
671
// / Optional - check if index is ok:
675
672
676
673
input.seekg (cur_element_count * size_data_per_element_,input.cur );
@@ -696,15 +693,11 @@ namespace hnswlib {
696
693
697
694
input.seekg (pos,input.beg );
698
695
699
-
700
696
data_level0_memory_ = (char *) malloc (max_elements * size_data_per_element_);
701
697
if (data_level0_memory_ == nullptr )
702
698
throw std::runtime_error (" Not enough memory: loadIndex failed to allocate level0" );
703
699
input.read (data_level0_memory_, cur_element_count * size_data_per_element_);
704
700
705
-
706
-
707
-
708
701
size_links_per_element_ = maxM_ * sizeof (tableint) + sizeof (linklistsizeint);
709
702
710
703
@@ -715,7 +708,6 @@ namespace hnswlib {
715
708
716
709
visited_list_pool_ = new VisitedListPool (1 , max_elements);
717
710
718
-
719
711
linkLists_ = (char **) malloc (sizeof (void *) * max_elements);
720
712
if (linkLists_ == nullptr )
721
713
throw std::runtime_error (" Not enough memory: loadIndex failed to allocate linklists" );
@@ -746,11 +738,22 @@ namespace hnswlib {
746
738
has_deletions_=true ;
747
739
}
748
740
749
- input.close ();
750
741
751
742
return ;
752
743
}
753
744
745
+
746
+
747
+ void loadIndex (const std::string &location, SpaceInterface<dist_t > *s, size_t max_elements_i=0 ) {
748
+ std::ifstream input (location, std::ios::binary);
749
+ if (!input.is_open ())
750
+ throw std::runtime_error (" Cannot open file" );
751
+
752
+ loadIndexFromStream (input, s, max_elements_i);
753
+ input.close ();
754
+ return ;
755
+ }
756
+
754
757
template <typename data_t >
755
758
std::vector<data_t > getDataByLabel (labeltype label)
756
759
{
@@ -874,7 +877,7 @@ namespace hnswlib {
874
877
for (auto && cand : sCand ) {
875
878
if (cand == neigh)
876
879
continue ;
877
-
880
+
878
881
dist_t distance = fstdistfunc_ (getDataByInternalId (neigh), getDataByInternalId (cand), dist_func_param_);
879
882
if (candidates.size () < elementsToKeep) {
880
883
candidates.emplace (distance, cand);
@@ -1137,7 +1140,7 @@ namespace hnswlib {
1137
1140
}
1138
1141
1139
1142
std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> top_candidates;
1140
- if (has_deletions_) {
1143
+ if (has_deletions_) {
1141
1144
top_candidates=searchBaseLayerST<true ,true >(
1142
1145
currObj, query_data, std::max (ef_, k));
1143
1146
}
@@ -1186,27 +1189,27 @@ namespace hnswlib {
1186
1189
std::unordered_set<tableint> s;
1187
1190
for (int j=0 ; j<size; j++){
1188
1191
assert (data[j] > 0 );
1189
- assert (data[j] < cur_element_count);
1192
+ assert (data[j] < cur_element_count);
1190
1193
assert (data[j] != i);
1191
1194
inbound_connections_num[data[j]]++;
1192
1195
s.insert (data[j]);
1193
1196
connections_checked++;
1194
-
1197
+
1195
1198
}
1196
1199
assert (s.size () == size);
1197
1200
}
1198
1201
}
1199
1202
if (cur_element_count > 1 ){
1200
1203
int min1=inbound_connections_num[0 ], max1=inbound_connections_num[0 ];
1201
- for (int i=0 ; i < cur_element_count; i++){
1204
+ for (int i=0 ; i < cur_element_count; i++){
1202
1205
assert (inbound_connections_num[i] > 0 );
1203
1206
min1=std::min (inbound_connections_num[i],min1);
1204
1207
max1=std::max (inbound_connections_num[i],max1);
1205
1208
}
1206
1209
std::cout << " Min inbound: " << min1 << " , Max inbound:" << max1 << " \n " ;
1207
1210
}
1208
1211
std::cout << " integrity ok, checked " << connections_checked << " connections\n " ;
1209
-
1212
+
1210
1213
}
1211
1214
1212
1215
};
0 commit comments