Skip to content

Commit 6bd7e37

Browse files
committed
[L2Space] Perf improvement for dimensions that are not a multiple of 4 or 16
Currently SIMD (SSE or AVX) is used only when the dimension is a multiple of 4 or 16; when the dimension is not a multiple of 4 or 16, a slower non-vectorized method is used. To improve performance for these cases, new methods are added: `L2SqrSIMD(4|16)ExtResiduals`. Each one relies on the existing `L2SqrSIMD(4|16)Ext` to process the largest leading part of the vectors whose length is a multiple of 4 or 16, and then finishes the remaining (residual) elements with the non-vectorized method `L2Sqr`.
1 parent a97ec89 commit 6bd7e37

File tree

1 file changed

+50
-29
lines changed

1 file changed

+50
-29
lines changed

hnswlib/space_l2.h

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
namespace hnswlib {
55

66
static float
7-
L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) {
8-
//return *((float *)pVect2);
7+
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
8+
float *pVect1 = (float *) pVect1v;
9+
float *pVect2 = (float *) pVect2v;
910
size_t qty = *((size_t *) qty_ptr);
11+
1012
float res = 0;
11-
for (unsigned i = 0; i < qty; i++) {
12-
float t = ((float *) pVect1)[i] - ((float *) pVect2)[i];
13+
for (size_t i = 0; i < qty; i++) {
14+
float t = *pVect1 - *pVect2;
15+
pVect1++;
16+
pVect2++;
1317
res += t * t;
1418
}
1519
return (res);
16-
1720
}
1821

1922
#if defined(USE_AVX)
@@ -49,10 +52,8 @@ namespace hnswlib {
4952
}
5053

5154
_mm256_store_ps(TmpRes, sum);
52-
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
53-
54-
return (res);
55-
}
55+
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
56+
}
5657

5758
#elif defined(USE_SSE)
5859

@@ -62,12 +63,9 @@ namespace hnswlib {
6263
float *pVect2 = (float *) pVect2v;
6364
size_t qty = *((size_t *) qty_ptr);
6465
float PORTABLE_ALIGN32 TmpRes[8];
65-
// size_t qty4 = qty >> 2;
6666
size_t qty16 = qty >> 4;
6767

6868
const float *pEnd1 = pVect1 + (qty16 << 4);
69-
// const float* pEnd2 = pVect1 + (qty4 << 2);
70-
// const float* pEnd3 = pVect1 + qty;
7169

7270
__m128 diff, v1, v2;
7371
__m128 sum = _mm_set1_ps(0);
@@ -102,10 +100,24 @@ namespace hnswlib {
102100
diff = _mm_sub_ps(v1, v2);
103101
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
104102
}
103+
105104
_mm_store_ps(TmpRes, sum);
106-
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
105+
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
106+
}
107+
#endif
107108

108-
return (res);
109+
#if defined(USE_SSE) || defined(USE_AVX)
110+
static float
111+
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
112+
size_t qty = *((size_t *) qty_ptr);
113+
size_t qty16 = qty >> 4 << 4;
114+
float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
115+
float *pVect1 = (float *) pVect1v + qty16;
116+
float *pVect2 = (float *) pVect2v + qty16;
117+
118+
size_t qty_left = qty - qty16;
119+
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
120+
return (res + res_tail);
109121
}
110122
#endif
111123

@@ -119,10 +131,9 @@ namespace hnswlib {
119131
size_t qty = *((size_t *) qty_ptr);
120132

121133

122-
// size_t qty4 = qty >> 2;
123-
size_t qty16 = qty >> 2;
134+
size_t qty4 = qty >> 2;
124135

125-
const float *pEnd1 = pVect1 + (qty16 << 2);
136+
const float *pEnd1 = pVect1 + (qty4 << 2);
126137

127138
__m128 diff, v1, v2;
128139
__m128 sum = _mm_set1_ps(0);
@@ -136,9 +147,22 @@ namespace hnswlib {
136147
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
137148
}
138149
_mm_store_ps(TmpRes, sum);
139-
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
150+
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
151+
}
140152

141-
return (res);
153+
static float
154+
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
155+
size_t qty = *((size_t *) qty_ptr);
156+
size_t qty4 = qty >> 2 << 2;
157+
158+
float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
159+
size_t qty_left = qty - qty4;
160+
161+
float *pVect1 = (float *) pVect1v + qty4;
162+
float *pVect2 = (float *) pVect2v + qty4;
163+
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
164+
165+
return (res + res_tail);
142166
}
143167
#endif
144168

@@ -151,13 +175,14 @@ namespace hnswlib {
151175
L2Space(size_t dim) {
152176
fstdistfunc_ = L2Sqr;
153177
#if defined(USE_SSE) || defined(USE_AVX)
154-
if (dim % 4 == 0)
155-
fstdistfunc_ = L2SqrSIMD4Ext;
156178
if (dim % 16 == 0)
157179
fstdistfunc_ = L2SqrSIMD16Ext;
158-
/*else{
159-
throw runtime_error("Data type not supported!");
160-
}*/
180+
else if (dim % 4 == 0)
181+
fstdistfunc_ = L2SqrSIMD4Ext;
182+
else if (dim > 16)
183+
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
184+
else if (dim > 4)
185+
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
161186
#endif
162187
dim_ = dim;
163188
data_size_ = dim * sizeof(float);
@@ -185,10 +210,6 @@ namespace hnswlib {
185210
int res = 0;
186211
unsigned char *a = (unsigned char *) pVect1;
187212
unsigned char *b = (unsigned char *) pVect2;
188-
/*for (int i = 0; i < qty; i++) {
189-
int t = int((a)[i]) - int((b)[i]);
190-
res += t*t;
191-
}*/
192213

193214
qty = qty >> 2;
194215
for (size_t i = 0; i < qty; i++) {
@@ -241,4 +262,4 @@ namespace hnswlib {
241262
};
242263

243264

244-
}
265+
}

0 commit comments

Comments
 (0)