1
+ #include " ../hnswlib/hnswlib.h"
2
+ #include < thread>
3
+ class StopW
4
+ {
5
+ std::chrono::steady_clock::time_point time_begin;
6
+
7
+ public:
8
+ StopW ()
9
+ {
10
+ time_begin = std::chrono::steady_clock::now ();
11
+ }
12
+
13
+ float getElapsedTimeMicro ()
14
+ {
15
+ std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now ();
16
+ return (std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_begin).count ());
17
+ }
18
+
19
+ void reset ()
20
+ {
21
+ time_begin = std::chrono::steady_clock::now ();
22
+ }
23
+ };
24
+
25
+ /*
26
+ * replacement for the openmp '#pragma omp parallel for' directive
27
+ * only handles a subset of functionality (no reductions etc)
28
+ * Process ids from start (inclusive) to end (EXCLUSIVE)
29
+ *
30
+ * The method is borrowed from nmslib
31
+ */
32
+ template <class Function >
33
+ inline void ParallelFor (size_t start, size_t end, size_t numThreads, Function fn) {
34
+ if (numThreads <= 0 ) {
35
+ numThreads = std::thread::hardware_concurrency ();
36
+ }
37
+
38
+ if (numThreads == 1 ) {
39
+ for (size_t id = start; id < end; id++) {
40
+ fn (id, 0 );
41
+ }
42
+ } else {
43
+ std::vector<std::thread> threads;
44
+ std::atomic<size_t > current (start);
45
+
46
+ // keep track of exceptions in threads
47
+ // https://stackoverflow.com/a/32428427/1713196
48
+ std::exception_ptr lastException = nullptr ;
49
+ std::mutex lastExceptMutex;
50
+
51
+ for (size_t threadId = 0 ; threadId < numThreads; ++threadId) {
52
+ threads.push_back (std::thread ([&, threadId] {
53
+ while (true ) {
54
+ size_t id = current.fetch_add (1 );
55
+
56
+ if ((id >= end)) {
57
+ break ;
58
+ }
59
+
60
+ try {
61
+ fn (id, threadId);
62
+ } catch (...) {
63
+ std::unique_lock<std::mutex> lastExcepLock (lastExceptMutex);
64
+ lastException = std::current_exception ();
65
+ /*
66
+ * This will work even when current is the largest value that
67
+ * size_t can fit, because fetch_add returns the previous value
68
+ * before the increment (what will result in overflow
69
+ * and produce 0 instead of current + 1).
70
+ */
71
+ current = end;
72
+ break ;
73
+ }
74
+ }
75
+ }));
76
+ }
77
+ for (auto &thread : threads) {
78
+ thread.join ();
79
+ }
80
+ if (lastException) {
81
+ std::rethrow_exception (lastException);
82
+ }
83
+ }
84
+
85
+
86
+ }
87
+
88
+
89
+ template <typename datatype>
90
+ std::vector<datatype> load_batch (std::string path, int size)
91
+ {
92
+ std::cout << " Loading " << path << " ..." ;
93
+ // float or int32 (python)
94
+ assert (sizeof (datatype) == 4 );
95
+
96
+ std::ifstream file;
97
+ file.open (path);
98
+ if (!file.is_open ())
99
+ {
100
+ std::cout << " Cannot open " << path << " \n " ;
101
+ exit (1 );
102
+ }
103
+ std::vector<datatype> batch (size);
104
+
105
+ file.read ((char *)batch.data (), size * sizeof (float ));
106
+ std::cout << " DONE\n " ;
107
+ return batch;
108
+ }
109
+
110
+ template <typename d_type>
111
+ static float
112
+ test_approx (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
113
+ std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
114
+ {
115
+ size_t correct = 0 ;
116
+ size_t total = 0 ;
117
+ // uncomment to test in parallel mode:
118
+
119
+
120
+ for (int i = 0 ; i < qsize; i++)
121
+ {
122
+
123
+ std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn ((char *)(queries.data () + vecdim * i), K);
124
+ total += K;
125
+ while (result.size ())
126
+ {
127
+ if (answers[i].find (result.top ().second ) != answers[i].end ())
128
+ {
129
+ correct++;
130
+ }
131
+ else
132
+ {
133
+ }
134
+ result.pop ();
135
+ }
136
+ }
137
+ return 1 .0f * correct / total;
138
+ }
139
+
140
+ static void
141
+ test_vs_recall (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<float > &appr_alg, size_t vecdim,
142
+ std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
143
+ {
144
+ std::vector<size_t > efs = {1 };
145
+ for (int i = k; i < 30 ; i++)
146
+ {
147
+ efs.push_back (i);
148
+ }
149
+ for (int i = 30 ; i < 400 ; i+=10 )
150
+ {
151
+ efs.push_back (i);
152
+ }
153
+ for (int i = 1000 ; i < 100000 ; i += 5000 )
154
+ {
155
+ efs.push_back (i);
156
+ }
157
+ std::cout << " ef\t recall\t time\t hops\t distcomp\n " ;
158
+ for (size_t ef : efs)
159
+ {
160
+ appr_alg.setEf (ef);
161
+
162
+ appr_alg.metric_hops =0 ;
163
+ appr_alg.metric_distance_computations =0 ;
164
+ StopW stopw = StopW ();
165
+
166
+ float recall = test_approx<float >(queries, qsize, appr_alg, vecdim, answers, k);
167
+ float time_us_per_query = stopw.getElapsedTimeMicro () / qsize;
168
+ float distance_comp_per_query = appr_alg.metric_distance_computations / (1 .0f * qsize);
169
+ float hops_per_query = appr_alg.metric_hops / (1 .0f * qsize);
170
+
171
+ std::cout << ef << " \t " << recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
172
+ if (recall > 0.99 )
173
+ {
174
+ std::cout << " Recall is over 0.99! " <<recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
175
+ break ;
176
+ }
177
+ }
178
+ }
179
+
180
+ int main (int argc, char **argv)
181
+ {
182
+
183
+ int M = 16 ;
184
+ int efConstruction = 200 ;
185
+ int num_threads = std::thread::hardware_concurrency ();
186
+
187
+
188
+
189
+ bool update = false ;
190
+
191
+ if (argc == 2 )
192
+ {
193
+ if (std::string (argv[1 ]) == " update" )
194
+ {
195
+ update = true ;
196
+ std::cout << " Updates are on\n " ;
197
+ }
198
+ else {
199
+ std::cout<<" Usage ./test_updates [update]\n " ;
200
+ exit (1 );
201
+ }
202
+ }
203
+ else if (argc>2 ){
204
+ std::cout<<" Usage ./test_updates [update]\n " ;
205
+ exit (1 );
206
+ }
207
+
208
+ std::string path = " ../examples/data/" ;
209
+
210
+
211
+ int N;
212
+ int dummy_data_multiplier;
213
+ int N_queries;
214
+ int d;
215
+ int K;
216
+ {
217
+ std::ifstream configfile;
218
+ configfile.open (path + " /config.txt" );
219
+ if (!configfile.is_open ())
220
+ {
221
+ std::cout << " Cannot open config.txt\n " ;
222
+ return 1 ;
223
+ }
224
+ configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K;
225
+
226
+ printf (" Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n " , N, dummy_data_multiplier, N_queries, d, K);
227
+ }
228
+
229
+ hnswlib::L2Space l2space (d);
230
+ hnswlib::HierarchicalNSW<float > appr_alg (&l2space, N + 1 , M, efConstruction);
231
+
232
+ std::vector<float > dummy_batch = load_batch<float >(path + " batch_dummy_00.bin" , N * d);
233
+
234
+ // Adding enterpoint:
235
+
236
+ appr_alg.addPoint ((void *)dummy_batch.data (), (size_t )0 );
237
+
238
+ StopW stopw = StopW ();
239
+
240
+ if (update)
241
+ {
242
+ std::cout << " Update iteration 0\n " ;
243
+
244
+
245
+ ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
246
+ appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
247
+ });
248
+ appr_alg.checkIntegrity ();
249
+
250
+ ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
251
+ appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
252
+ });
253
+ appr_alg.checkIntegrity ();
254
+
255
+ for (int b = 1 ; b < dummy_data_multiplier; b++)
256
+ {
257
+ std::cout << " Update iteration " << b << " \n " ;
258
+ char cpath[1024 ];
259
+ sprintf (cpath, " batch_dummy_%02d.bin" , b);
260
+ std::vector<float > dummy_batchb = load_batch<float >(path + cpath, N * d);
261
+
262
+ ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
263
+ appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
264
+ });
265
+ appr_alg.checkIntegrity ();
266
+ }
267
+ }
268
+
269
+ std::cout << " Inserting final elements\n " ;
270
+ std::vector<float > final_batch = load_batch<float >(path + " batch_final.bin" , N * d);
271
+
272
+ stopw.reset ();
273
+ ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
274
+ appr_alg.addPoint ((void *)(final_batch.data () + i * d), i);
275
+ });
276
+ std::cout<<" Finished. Time taken:" << stopw.getElapsedTimeMicro ()*1e-6 << " s\n " ;
277
+ std::cout << " Running tests\n " ;
278
+ std::vector<float > queries_batch = load_batch<float >(path + " queries.bin" , N_queries * d);
279
+
280
+ std::vector<int > gt = load_batch<int >(path + " gt.bin" , N_queries * K);
281
+
282
+ std::vector<std::unordered_set<hnswlib::labeltype>> answers (N_queries);
283
+ for (int i = 0 ; i < N_queries; i++)
284
+ {
285
+ for (int j = 0 ; j < K; j++)
286
+ {
287
+ answers[i].insert (gt[i * K + j]);
288
+ }
289
+ }
290
+
291
+ for (int i = 0 ; i < 3 ; i++)
292
+ {
293
+ std::cout << " Test iteration " << i << " \n " ;
294
+ test_vs_recall (queries_batch, N_queries, appr_alg, d, answers, K);
295
+ }
296
+
297
+ return 0 ;
298
+ };
0 commit comments