diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index e9467473..d0497ff7 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -211,6 +211,36 @@ namespace hnswlib { #endif +#if defined(USE_SSE) || defined(USE_AVX) + static float + InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + return res + res_tail - 1.0f; + } + + static float + InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; + + float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; + + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + + return res + res_tail - 1.0f; + } +#endif + class InnerProductSpace : public SpaceInterface { DISTFUNC fstdistfunc_; @@ -220,11 +250,15 @@ namespace hnswlib { InnerProductSpace(size_t dim) { fstdistfunc_ = InnerProduct; #if defined(USE_AVX) || defined(USE_SSE) - if (dim % 4 == 0) - fstdistfunc_ = InnerProductSIMD4Ext; if (dim % 16 == 0) fstdistfunc_ = InnerProductSIMD16Ext; -#endif + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductSIMD4ExtResiduals; + #endif dim_ = dim; data_size_ = dim * sizeof(float); } diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h index 4d3ac69a..bc00af72 100644 --- a/hnswlib/space_l2.h +++ b/hnswlib/space_l2.h @@ -4,16 +4,19 @@ namespace hnswlib { static float - L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) { - //return *((float *)pVect2); + L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; size_t qty = *((size_t *) qty_ptr); + float res = 0; - for (unsigned i = 0; i < qty; i++) { - float t = ((float *) pVect1)[i] - ((float *) pVect2)[i]; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; res += t * t; } return (res); - } #if defined(USE_AVX) @@ -49,10 +52,8 @@ namespace hnswlib { } _mm256_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; - - return (res); -} + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + } #elif defined(USE_SSE) @@ -62,12 +63,9 @@ namespace hnswlib { float *pVect2 = (float *) pVect2v; size_t qty = *((size_t *) qty_ptr); float PORTABLE_ALIGN32 TmpRes[8]; - // size_t qty4 = qty >> 2; size_t qty16 = qty >> 4; const float *pEnd1 = pVect1 + (qty16 << 4); - // const float* pEnd2 = pVect1 + (qty4 << 2); - // const float* pEnd3 = pVect1 + qty; __m128 diff, v1, v2; __m128 sum = _mm_set1_ps(0); @@ -102,10 +100,24 @@ namespace hnswlib { diff = _mm_sub_ps(v1, v2); sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } + _mm_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + } +#endif - return (res); +#if defined(USE_SSE) || defined(USE_AVX) + static float + L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + return (res + res_tail); } #endif @@ -119,10 +131,9 @@ namespace hnswlib { size_t qty = *((size_t *) qty_ptr); - // size_t qty4 = qty >> 2; - size_t qty16 = qty >> 2; + size_t qty4 = qty >> 2; - const float *pEnd1 = pVect1 + (qty16 << 2); + const float *pEnd1 = pVect1 + (qty4 << 2); __m128 diff, v1, v2; __m128 sum = _mm_set1_ps(0); @@ -136,9 +147,22 @@ namespace hnswlib { sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } _mm_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + } - return (res); + static float + L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; + + float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; + + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + + return (res + res_tail); } #endif @@ -151,13 +175,14 @@ namespace hnswlib { L2Space(size_t dim) { fstdistfunc_ = L2Sqr; #if defined(USE_SSE) || defined(USE_AVX) - if (dim % 4 == 0) - fstdistfunc_ = L2SqrSIMD4Ext; if (dim % 16 == 0) fstdistfunc_ = L2SqrSIMD16Ext; - /*else{ - throw runtime_error("Data type not supported!"); - }*/ + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; #endif dim_ = dim; data_size_ = dim * sizeof(float); @@ -185,10 +210,6 @@ namespace hnswlib { int res = 0; unsigned char *a = (unsigned char *) pVect1; unsigned char *b = (unsigned char *) pVect2; - /*for (int i = 0; i < qty; i++) { - int t = int((a)[i]) - int((b)[i]); - res += t*t; - }*/ qty = qty >> 2; for (size_t i = 0; i < qty; i++) { @@ -241,4 +262,4 @@ namespace hnswlib { }; -} +} \ No newline at end of file