Skip to content

Perf improvement for dimension not of factor 4 and 16 #211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions hnswlib/space_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,36 @@ namespace hnswlib {

#endif

#if defined(USE_SSE) || defined(USE_AVX)
static float
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty >> 4 << 4;
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
float *pVect1 = (float *) pVect1v + qty16;
float *pVect2 = (float *) pVect2v + qty16;

size_t qty_left = qty - qty16;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
return res + res_tail - 1.0f;
}

static float
InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2 << 2;

float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
size_t qty_left = qty - qty4;

float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);

return res + res_tail - 1.0f;
}
#endif

class InnerProductSpace : public SpaceInterface<float> {

DISTFUNC<float> fstdistfunc_;
Expand All @@ -220,11 +250,15 @@ namespace hnswlib {
InnerProductSpace(size_t dim) {
fstdistfunc_ = InnerProduct;
#if defined(USE_AVX) || defined(USE_SSE)
if (dim % 4 == 0)
fstdistfunc_ = InnerProductSIMD4Ext;
if (dim % 16 == 0)
fstdistfunc_ = InnerProductSIMD16Ext;
#endif
else if (dim % 4 == 0)
fstdistfunc_ = InnerProductSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = InnerProductSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = InnerProductSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
}
Expand Down
79 changes: 50 additions & 29 deletions hnswlib/space_l2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
namespace hnswlib {

static float
L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) {
//return *((float *)pVect2);
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);

float res = 0;
for (unsigned i = 0; i < qty; i++) {
float t = ((float *) pVect1)[i] - ((float *) pVect2)[i];
for (size_t i = 0; i < qty; i++) {
float t = *pVect1 - *pVect2;
pVect1++;
pVect2++;
res += t * t;
}
return (res);

}

#if defined(USE_AVX)
Expand Down Expand Up @@ -49,10 +52,8 @@ namespace hnswlib {
}

_mm256_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];

return (res);
}
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
}

#elif defined(USE_SSE)

Expand All @@ -62,12 +63,9 @@ namespace hnswlib {
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN32 TmpRes[8];
// size_t qty4 = qty >> 2;
size_t qty16 = qty >> 4;

const float *pEnd1 = pVect1 + (qty16 << 4);
// const float* pEnd2 = pVect1 + (qty4 << 2);
// const float* pEnd3 = pVect1 + qty;

__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
Expand Down Expand Up @@ -102,10 +100,24 @@ namespace hnswlib {
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}

_mm_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
#endif

return (res);
#if defined(USE_SSE) || defined(USE_AVX)
static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty >> 4 << 4;
float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
float *pVect1 = (float *) pVect1v + qty16;
float *pVect2 = (float *) pVect2v + qty16;

size_t qty_left = qty - qty16;
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
return (res + res_tail);
}
#endif

Expand All @@ -119,10 +131,9 @@ namespace hnswlib {
size_t qty = *((size_t *) qty_ptr);


// size_t qty4 = qty >> 2;
size_t qty16 = qty >> 2;
size_t qty4 = qty >> 2;

const float *pEnd1 = pVect1 + (qty16 << 2);
const float *pEnd1 = pVect1 + (qty4 << 2);

__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
Expand All @@ -136,9 +147,22 @@ namespace hnswlib {
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
_mm_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}

return (res);
static float
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2 << 2;

float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
size_t qty_left = qty - qty4;

float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);

return (res + res_tail);
}
#endif

Expand All @@ -151,13 +175,14 @@ namespace hnswlib {
L2Space(size_t dim) {
fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX)
if (dim % 4 == 0)
fstdistfunc_ = L2SqrSIMD4Ext;
if (dim % 16 == 0)
fstdistfunc_ = L2SqrSIMD16Ext;
/*else{
throw runtime_error("Data type not supported!");
}*/
else if (dim % 4 == 0)
fstdistfunc_ = L2SqrSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
Expand Down Expand Up @@ -185,10 +210,6 @@ namespace hnswlib {
int res = 0;
unsigned char *a = (unsigned char *) pVect1;
unsigned char *b = (unsigned char *) pVect2;
/*for (int i = 0; i < qty; i++) {
int t = int((a)[i]) - int((b)[i]);
res += t*t;
}*/

qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
Expand Down Expand Up @@ -241,4 +262,4 @@ namespace hnswlib {
};


}
}