Skip to content

Commit c247540

Browse files
authored
Merge pull request #216 from apoorv-sharma/update_patch
Algorithm to perform dynamic/incremental updates of feature vectors
2 parents ba16931 + 524873b commit c247540

File tree

6 files changed

+666
-73
lines changed

6 files changed

+666
-73
lines changed

CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,22 @@ include_directories("${PROJECT_BINARY_DIR}")
55

66

77

8-
set(SOURCE_EXE main.cpp)
8+
set(SOURCE_EXE main.cpp)
99

1010
set(SOURCE_LIB sift_1b.cpp)
1111

1212
add_library(sift_test STATIC ${SOURCE_LIB})
1313

1414

1515
add_executable(main ${SOURCE_EXE})
16+
# Pick optimisation flags per compiler family.
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
  SET( CMAKE_CXX_FLAGS "-Ofast -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -ftree-vectorize")
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
  SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" )
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
  # cl.exe does not understand GCC-style options (-Ofast, -lrt, -march=native,
  # -fpic, -fopenmp, ...), so use the MSVC spellings instead:
  #   /O2 optimise for speed, /openmp enable OpenMP, /EHsc standard C++
  #   exception handling, /w suppress warnings (counterpart of GCC's -w).
  SET( CMAKE_CXX_FLAGS "/O2 /DNDEBUG /DHAVE_CXX0X /openmp /EHsc /w" )
endif()
23+
24+
add_executable(test_updates examples/updates_test.cpp)
25+
1726
target_link_libraries(main sift_test)

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,29 @@ To run the test on 200M SIFT subset:
223223

224224
The size of the bigann subset (in millions) is controlled by the variable **subset_size_milllions** hardcoded in **sift_1b.cpp**.
225225

226+
### Updates test
227+
To generate testing data (from root directory):
228+
```bash
229+
cd examples
230+
python update_gen_data.py
231+
```
232+
To compile (from root directory):
233+
```bash
234+
mkdir build
235+
cd build
236+
cmake ..
237+
make
238+
```
239+
To run test **without** updates (from `build` directory):
240+
```bash
241+
./test_updates
242+
```
243+
244+
To run test **with** updates (from `build` directory):
245+
```bash
246+
./test_updates update
247+
```
248+
226249
### HNSW example demos
227250

228251
- Visual search engine for 1M amazon products (MXNet + HNSW): [website](https://thomasdelteil.github.io/VisualSearch_MXNet/), [code](https://github.com/ThomasDelteil/VisualSearch_MXNet), demo by [@ThomasDelteil](https://github.com/ThomasDelteil)

examples/update_gen_data.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
import os
3+
4+
def normalized(a, axis=-1, order=2):
    """Scale `a` to unit norm along `axis`.

    The norm of the given `order` is computed along `axis`; slices whose
    norm is exactly zero are divided by 1 instead, so they come back
    unchanged rather than producing NaN/inf.
    """
    norms = np.atleast_1d(np.linalg.norm(a, ord=order, axis=axis))
    zero_mask = norms == 0
    norms[zero_mask] = 1
    divisor = np.expand_dims(norms, axis)
    return a / divisor
8+
9+
# Dataset parameters; they are written to data/config.txt at the end so the
# C++ test can read them back.
N = 100000                 # vectors per batch
dummy_data_multiplier = 3  # number of dummy batches generated
N_queries = 1000           # number of query vectors
d = 8                      # vector dimensionality
K = 5                      # neighbours kept per query in the ground truth

# Fixed seed so that repeated runs produce identical files.
np.random.seed(1)

print("Generating data...")
# The order of the random draws (dummy batches, final batch, queries) is kept
# fixed so the generated data stays reproducible for a given seed.
batches_dummy = [normalized(np.float32(np.random.random((N, d)))) for _ in range(dummy_data_multiplier)]
batch_final = normalized(np.float32(np.random.random((N, d))))
queries = normalized(np.float32(np.random.random((N_queries, d))))
print("Computing distances...")
# Ground truth: for each query, the K rows of batch_final with the largest dot product.
dist = np.dot(queries, batch_final.T)
topk = np.argsort(-dist)[:, :K]
print("Saving...")

# Create the output directory only when it is missing. Any real failure
# (e.g. permissions, or a regular file named "data") now surfaces here instead
# of being swallowed and reappearing later as a confusing write error.
if not os.path.isdir("data"):
    os.mkdir("data")

for idx, batch_dummy in enumerate(batches_dummy):
    batch_dummy.tofile('data/batch_dummy_%02d.bin' % idx)
batch_final.tofile('data/batch_final.bin')
queries.tofile('data/queries.bin')
np.int32(topk).tofile('data/gt.bin')
with open("data/config.txt", "w") as config_file:
    config_file.write("%d %d %d %d %d" % (N, dummy_data_multiplier, N_queries, d, K))

examples/updates_test.cpp

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
#include "../hnswlib/hnswlib.h"
2+
#include <thread>
3+
// Minimal stopwatch built on std::chrono::steady_clock.
class StopW
{
    using Clock = std::chrono::steady_clock;

    // Moment the current measurement started.
    Clock::time_point time_begin;

public:
    // Starts measuring at construction time.
    StopW() : time_begin(Clock::now()) {}

    // Returns the whole microseconds elapsed since construction or the last
    // reset(), converted to float.
    float getElapsedTimeMicro()
    {
        const Clock::time_point now = Clock::now();
        const std::chrono::microseconds elapsed =
            std::chrono::duration_cast<std::chrono::microseconds>(now - time_begin);
        return elapsed.count();
    }

    // Restarts the measurement from the current instant.
    void reset()
    {
        time_begin = Clock::now();
    }
};
24+
25+
/*
 * Replacement for the OpenMP '#pragma omp parallel for' directive.
 * Only handles a subset of the functionality (no reductions etc).
 * Processes ids from start (inclusive) to end (EXCLUSIVE).
 *
 * fn is called as fn(id, threadId), where threadId is in [0, numThreads).
 * numThreads == 0 means "use std::thread::hardware_concurrency()"; when that
 * value cannot be determined (it may legally be 0) a single thread is used,
 * so every id is still processed.
 *
 * If fn throws, remaining ids are abandoned, all workers are joined and the
 * last captured exception is rethrown on the calling thread.
 *
 * The method is borrowed from nmslib.
 */
template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    if (numThreads == 0) {
        numThreads = std::thread::hardware_concurrency();
        if (numThreads == 0) {
            // Concurrency level unknown: run sequentially rather than spawning
            // zero workers and silently skipping all the work.
            numThreads = 1;
        }
    }

    if (numThreads == 1) {
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
        return;
    }

    std::vector<std::thread> threads;
    threads.reserve(numThreads);
    // Next id to hand out; workers claim ids one at a time.
    std::atomic<size_t> current(start);

    // keep track of exceptions in threads
    // https://stackoverflow.com/a/32428427/1713196
    std::exception_ptr lastException = nullptr;
    std::mutex lastExceptMutex;

    for (size_t threadId = 0; threadId < numThreads; ++threadId) {
        threads.push_back(std::thread([&, threadId] {
            while (true) {
                size_t id = current.fetch_add(1);

                if (id >= end) {
                    break;
                }

                try {
                    fn(id, threadId);
                } catch (...) {
                    std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
                    lastException = std::current_exception();
                    /*
                     * This will work even when current is the largest value that
                     * size_t can fit, because fetch_add returns the previous value
                     * before the increment (what will result in overflow
                     * and produce 0 instead of current + 1).
                     */
                    current = end;
                    break;
                }
            }
        }));
    }
    for (auto &thread : threads) {
        thread.join();
    }
    if (lastException) {
        std::rethrow_exception(lastException);
    }
}
87+
88+
89+
// Loads `size` raw 4-byte elements (float or int32, as written by the Python
// generator via ndarray.tofile) from the binary file at `path`.
//
// Prints progress to stdout. Terminates the process with exit(1) when the
// file cannot be opened or holds fewer than `size` elements, so a missing or
// truncated data file is reported instead of silently yielding zeros.
template <typename datatype>
std::vector<datatype> load_batch(std::string path, int size)
{
    // float or int32 (python); checked at compile time so that -DNDEBUG
    // builds keep the guarantee.
    static_assert(sizeof(datatype) == 4, "load_batch expects 4-byte elements");

    std::cout << "Loading " << path << "...";

    // Binary mode: text mode would translate/truncate bytes on some platforms.
    std::ifstream file(path, std::ios::in | std::ios::binary);
    if (!file.is_open())
    {
        std::cout << "Cannot open " << path << "\n";
        exit(1);
    }

    std::vector<datatype> batch(size);
    const std::streamsize wanted = static_cast<std::streamsize>(batch.size() * sizeof(datatype));

    file.read(reinterpret_cast<char *>(batch.data()), wanted);
    if (file.gcount() != wanted)
    {
        std::cout << "Cannot read " << size << " elements from " << path << "\n";
        exit(1);
    }

    std::cout << " DONE\n";
    return batch;
}
109+
110+
template <typename d_type>
111+
static float
112+
test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
113+
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
114+
{
115+
size_t correct = 0;
116+
size_t total = 0;
117+
//uncomment to test in parallel mode:
118+
119+
120+
for (int i = 0; i < qsize; i++)
121+
{
122+
123+
std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K);
124+
total += K;
125+
while (result.size())
126+
{
127+
if (answers[i].find(result.top().second) != answers[i].end())
128+
{
129+
correct++;
130+
}
131+
else
132+
{
133+
}
134+
result.pop();
135+
}
136+
}
137+
return 1.0f * correct / total;
138+
}
139+
140+
static void
141+
test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<float> &appr_alg, size_t vecdim,
142+
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
143+
{
144+
std::vector<size_t> efs = {1};
145+
for (int i = k; i < 30; i++)
146+
{
147+
efs.push_back(i);
148+
}
149+
for (int i = 30; i < 400; i+=10)
150+
{
151+
efs.push_back(i);
152+
}
153+
for (int i = 1000; i < 100000; i += 5000)
154+
{
155+
efs.push_back(i);
156+
}
157+
std::cout << "ef\trecall\ttime\thops\tdistcomp\n";
158+
for (size_t ef : efs)
159+
{
160+
appr_alg.setEf(ef);
161+
162+
appr_alg.metric_hops=0;
163+
appr_alg.metric_distance_computations=0;
164+
StopW stopw = StopW();
165+
166+
float recall = test_approx<float>(queries, qsize, appr_alg, vecdim, answers, k);
167+
float time_us_per_query = stopw.getElapsedTimeMicro() / qsize;
168+
float distance_comp_per_query = appr_alg.metric_distance_computations / (1.0f * qsize);
169+
float hops_per_query = appr_alg.metric_hops / (1.0f * qsize);
170+
171+
std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
172+
if (recall > 0.99)
173+
{
174+
std::cout << "Recall is over 0.99! "<<recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
175+
break;
176+
}
177+
}
178+
}
179+
180+
int main(int argc, char **argv)
181+
{
182+
183+
int M = 16;
184+
int efConstruction = 200;
185+
int num_threads = std::thread::hardware_concurrency();
186+
187+
188+
189+
bool update = false;
190+
191+
if (argc == 2)
192+
{
193+
if (std::string(argv[1]) == "update")
194+
{
195+
update = true;
196+
std::cout << "Updates are on\n";
197+
}
198+
else {
199+
std::cout<<"Usage ./test_updates [update]\n";
200+
exit(1);
201+
}
202+
}
203+
else if (argc>2){
204+
std::cout<<"Usage ./test_updates [update]\n";
205+
exit(1);
206+
}
207+
208+
std::string path = "../examples/data/";
209+
210+
211+
int N;
212+
int dummy_data_multiplier;
213+
int N_queries;
214+
int d;
215+
int K;
216+
{
217+
std::ifstream configfile;
218+
configfile.open(path + "/config.txt");
219+
if (!configfile.is_open())
220+
{
221+
std::cout << "Cannot open config.txt\n";
222+
return 1;
223+
}
224+
configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K;
225+
226+
printf("Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n", N, dummy_data_multiplier, N_queries, d, K);
227+
}
228+
229+
hnswlib::L2Space l2space(d);
230+
hnswlib::HierarchicalNSW<float> appr_alg(&l2space, N + 1, M, efConstruction);
231+
232+
std::vector<float> dummy_batch = load_batch<float>(path + "batch_dummy_00.bin", N * d);
233+
234+
// Adding enterpoint:
235+
236+
appr_alg.addPoint((void *)dummy_batch.data(), (size_t)0);
237+
238+
StopW stopw = StopW();
239+
240+
if (update)
241+
{
242+
std::cout << "Update iteration 0\n";
243+
244+
245+
ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
246+
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
247+
});
248+
appr_alg.checkIntegrity();
249+
250+
ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
251+
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
252+
});
253+
appr_alg.checkIntegrity();
254+
255+
for (int b = 1; b < dummy_data_multiplier; b++)
256+
{
257+
std::cout << "Update iteration " << b << "\n";
258+
char cpath[1024];
259+
sprintf(cpath, "batch_dummy_%02d.bin", b);
260+
std::vector<float> dummy_batchb = load_batch<float>(path + cpath, N * d);
261+
262+
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
263+
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
264+
});
265+
appr_alg.checkIntegrity();
266+
}
267+
}
268+
269+
std::cout << "Inserting final elements\n";
270+
std::vector<float> final_batch = load_batch<float>(path + "batch_final.bin", N * d);
271+
272+
stopw.reset();
273+
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
274+
appr_alg.addPoint((void *)(final_batch.data() + i * d), i);
275+
});
276+
std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n";
277+
std::cout << "Running tests\n";
278+
std::vector<float> queries_batch = load_batch<float>(path + "queries.bin", N_queries * d);
279+
280+
std::vector<int> gt = load_batch<int>(path + "gt.bin", N_queries * K);
281+
282+
std::vector<std::unordered_set<hnswlib::labeltype>> answers(N_queries);
283+
for (int i = 0; i < N_queries; i++)
284+
{
285+
for (int j = 0; j < K; j++)
286+
{
287+
answers[i].insert(gt[i * K + j]);
288+
}
289+
}
290+
291+
for (int i = 0; i < 3; i++)
292+
{
293+
std::cout << "Test iteration " << i << "\n";
294+
test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K);
295+
}
296+
297+
return 0;
298+
};

0 commit comments

Comments
 (0)