1
+ #include < assert.h>
2
+ #include " ../benchmark/include/benchmark/benchmark.h"
3
+ #include < vector>
4
+ #include < string>
5
+ #include < fstream>
6
+ #include < sstream>
7
+ #include < iostream>
8
+ #include " hnswlib.h"
9
+
10
+ class L2BaseBench : public benchmark ::Fixture {
11
+ public:
12
+ size_t dimension;
13
+ size_t nb_embeddings = 100 ;
14
+
15
+ hnswlib::L2Space* space;
16
+ void * dist_func_param;
17
+ hnswlib::DISTFUNC<float > dist_func;
18
+ std::vector<std::vector<float >> scenario_input;
19
+ size_t scenario_input_index;
20
+ size_t scenario_input_size;
21
+
22
+ const float RAND_MAX_FLOAT = (float )(RAND_MAX);
23
+
24
+ void SetUp (const ::benchmark::State& state) {
25
+ scenario_input.clear ();
26
+ scenario_input.reserve (nb_embeddings);
27
+ for (int i = 0 ; i < nb_embeddings; i++) {
28
+ std::vector<float > vector (dimension);
29
+ for (int j = 0 ; j < dimension; j++) {
30
+ vector[j] = (float )rand ()/RAND_MAX_FLOAT;
31
+ }
32
+ scenario_input.push_back (vector);
33
+ }
34
+ space = new hnswlib::L2Space (dimension);
35
+ dist_func_param = space->get_dist_func_param ();
36
+ dist_func = space->get_dist_func ();
37
+
38
+ scenario_input_size = scenario_input.size ();
39
+ scenario_input_index = 0 ;
40
+ }
41
+
42
+ void TearDown (const ::benchmark::State& state) {
43
+ delete space;
44
+ }
45
+
46
+ float * get_vector () {
47
+ float * vector_data = scenario_input[scenario_input_index].data ();
48
+ scenario_input_index++;
49
+ if (scenario_input_index == scenario_input_size) {
50
+ scenario_input_index = 0 ;
51
+ }
52
+ return vector_data;
53
+ }
54
+
55
+ float compute_distance () {
56
+ auto vector1 = get_vector ();
57
+ auto vector2 = get_vector ();
58
+ return dist_func (vector1, vector2, dist_func_param);
59
+ }
60
+ };
61
+
62
+ #define L2DimBench (dim )\
63
+ class Dim ##dim : public L2BaseBench {\
64
+ public: Dim##dim() {\
65
+ dimension = dim;\
66
+ }\
67
+ };\
68
+ BENCHMARK_DEFINE_F (Dim##dim, Dist)(benchmark::State& st) { for (auto _ : st) compute_distance (); }\
69
+ BENCHMARK_REGISTER_F (Dim##dim, Dist);
70
+
71
+ L2DimBench (3 );
72
+ L2DimBench (4 );
73
+ L2DimBench (7 );
74
+ L2DimBench (8 );
75
+ L2DimBench (9 );
76
+ L2DimBench (15 );
77
+ L2DimBench (16 );
78
+ L2DimBench (100 );
79
+ L2DimBench (101 );
80
+ L2DimBench (128 );
81
+ L2DimBench (129 );
82
+
83
+ BENCHMARK_DEFINE_F (Dim3, Prep)(benchmark::State& st) {
84
+ for (auto _ : st) {
85
+ get_vector ();
86
+ get_vector ();
87
+ }
88
+ }
89
+
90
+ BENCHMARK_MAIN ();
0 commit comments