4
4
namespace hnswlib {
5
5
6
6
static float
7
- L2Sqr (const void *pVect1, const void *pVect2, const void *qty_ptr) {
8
- // return *((float *)pVect2);
7
+ L2Sqr (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
8
+ float *pVect1 = (float *) pVect1v;
9
+ float *pVect2 = (float *) pVect2v;
9
10
size_t qty = *((size_t *) qty_ptr);
11
+
10
12
float res = 0 ;
11
- for (unsigned i = 0 ; i < qty; i++) {
12
- float t = ((float *) pVect1)[i] - ((float *) pVect2)[i];
13
+ for (size_t i = 0 ; i < qty; i++) {
14
+ float t = *pVect1 - *pVect2;
15
+ pVect1++;
16
+ pVect2++;
13
17
res += t * t;
14
18
}
15
19
return (res);
16
-
17
20
}
18
21
19
22
#if defined(USE_AVX)
@@ -49,10 +52,8 @@ namespace hnswlib {
49
52
}
50
53
51
54
_mm256_store_ps (TmpRes, sum);
52
- float res = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ];
53
-
54
- return (res);
55
- }
55
+ return TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ];
56
+ }
56
57
57
58
#elif defined(USE_SSE)
58
59
@@ -62,12 +63,9 @@ namespace hnswlib {
62
63
float *pVect2 = (float *) pVect2v;
63
64
size_t qty = *((size_t *) qty_ptr);
64
65
float PORTABLE_ALIGN32 TmpRes[8 ];
65
- // size_t qty4 = qty >> 2;
66
66
size_t qty16 = qty >> 4 ;
67
67
68
68
const float *pEnd1 = pVect1 + (qty16 << 4 );
69
- // const float* pEnd2 = pVect1 + (qty4 << 2);
70
- // const float* pEnd3 = pVect1 + qty;
71
69
72
70
__m128 diff, v1, v2;
73
71
__m128 sum = _mm_set1_ps (0 );
@@ -102,10 +100,24 @@ namespace hnswlib {
102
100
diff = _mm_sub_ps (v1, v2);
103
101
sum = _mm_add_ps (sum, _mm_mul_ps (diff, diff));
104
102
}
103
+
105
104
_mm_store_ps (TmpRes, sum);
106
- float res = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
105
+ return TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
106
+ }
107
+ #endif
107
108
108
- return (res);
109
+ #if defined(USE_SSE) || defined(USE_AVX)
110
+ static float
111
+ L2SqrSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
112
+ size_t qty = *((size_t *) qty_ptr);
113
+ size_t qty16 = qty >> 4 << 4 ;
114
+ float res = L2SqrSIMD16Ext (pVect1v, pVect2v, &qty16);
115
+ float *pVect1 = (float *) pVect1v + qty16;
116
+ float *pVect2 = (float *) pVect2v + qty16;
117
+
118
+ size_t qty_left = qty - qty16;
119
+ float res_tail = L2Sqr (pVect1, pVect2, &qty_left);
120
+ return (res + res_tail);
109
121
}
110
122
#endif
111
123
@@ -119,10 +131,9 @@ namespace hnswlib {
119
131
size_t qty = *((size_t *) qty_ptr);
120
132
121
133
122
- // size_t qty4 = qty >> 2;
123
- size_t qty16 = qty >> 2 ;
134
+ size_t qty4 = qty >> 2 ;
124
135
125
- const float *pEnd1 = pVect1 + (qty16 << 2 );
136
+ const float *pEnd1 = pVect1 + (qty4 << 2 );
126
137
127
138
__m128 diff, v1, v2;
128
139
__m128 sum = _mm_set1_ps (0 );
@@ -136,9 +147,22 @@ namespace hnswlib {
136
147
sum = _mm_add_ps (sum, _mm_mul_ps (diff, diff));
137
148
}
138
149
_mm_store_ps (TmpRes, sum);
139
- float res = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
150
+ return TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
151
+ }
140
152
141
- return (res);
153
+ static float
154
+ L2SqrSIMD4ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
155
+ size_t qty = *((size_t *) qty_ptr);
156
+ size_t qty4 = qty >> 2 << 2 ;
157
+
158
+ float res = L2SqrSIMD4Ext (pVect1v, pVect2v, &qty4);
159
+ size_t qty_left = qty - qty4;
160
+
161
+ float *pVect1 = (float *) pVect1v + qty4;
162
+ float *pVect2 = (float *) pVect2v + qty4;
163
+ float res_tail = L2Sqr (pVect1, pVect2, &qty_left);
164
+
165
+ return (res + res_tail);
142
166
}
143
167
#endif
144
168
@@ -151,13 +175,14 @@ namespace hnswlib {
151
175
L2Space (size_t dim) {
152
176
fstdistfunc_ = L2Sqr;
153
177
#if defined(USE_SSE) || defined(USE_AVX)
154
- if (dim % 4 == 0 )
155
- fstdistfunc_ = L2SqrSIMD4Ext;
156
178
if (dim % 16 == 0 )
157
179
fstdistfunc_ = L2SqrSIMD16Ext;
158
- /* else{
159
- throw runtime_error("Data type not supported!");
160
- }*/
180
+ else if (dim % 4 == 0 )
181
+ fstdistfunc_ = L2SqrSIMD4Ext;
182
+ else if (dim > 16 )
183
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
184
+ else if (dim > 4 )
185
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
161
186
#endif
162
187
dim_ = dim;
163
188
data_size_ = dim * sizeof (float );
@@ -185,10 +210,6 @@ namespace hnswlib {
185
210
int res = 0 ;
186
211
unsigned char *a = (unsigned char *) pVect1;
187
212
unsigned char *b = (unsigned char *) pVect2;
188
- /* for (int i = 0; i < qty; i++) {
189
- int t = int((a)[i]) - int((b)[i]);
190
- res += t*t;
191
- }*/
192
213
193
214
qty = qty >> 2 ;
194
215
for (size_t i = 0 ; i < qty; i++) {
@@ -241,4 +262,4 @@ namespace hnswlib {
241
262
};
242
263
243
264
244
- }
265
+ }
0 commit comments