diff --git a/CMakeLists.txt b/CMakeLists.txt index ebee6e6c..31935e0e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,4 +23,6 @@ endif() add_executable(test_updates examples/updates_test.cpp) +add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp) + target_link_libraries(main sift_test) diff --git a/examples/searchKnnCloserFirst_test.cpp b/examples/searchKnnCloserFirst_test.cpp new file mode 100644 index 00000000..cc1392c8 --- /dev/null +++ b/examples/searchKnnCloserFirst_test.cpp @@ -0,0 +1,84 @@ +// This is a test file for testing the interface +// >>> virtual std::vector> +// >>> searchKnnCloserFirst(const void* query_data, size_t k) const; +// of class AlgorithmInterface + +#include "../hnswlib/hnswlib.h" + +#include + +#include +#include + +namespace +{ + +using idx_t = hnswlib::labeltype; + +void test() { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + alg_brute->addPoint(data.data() + d * i, i); + alg_hnsw->addPoint(data.data() + d * i, i); + } + + // test searchKnnCloserFirst of BruteforceSearch + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k); + auto res = alg_brute->searchKnnCloserFirst(p, k); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + gd.pop(); + } + } + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = 
alg_hnsw->searchKnn(p, k); + auto res = alg_hnsw->searchKnnCloserFirst(p, k); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + gd.pop(); + } + } + + delete alg_brute; + delete alg_hnsw; +} + +} // namespace + +int main() { + std::cout << "Testing ..." << std::endl; + test(); + std::cout << "Test ok" << std::endl; + + return 0; +} diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 5b1bd655..24260400 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -111,24 +111,6 @@ namespace hnswlib { return topResults; }; - template - std::vector> - searchKnn(const void* query_data, size_t k, Comp comp) { - std::vector> result; - if (cur_element_count == 0) return result; - - auto ret = searchKnn(query_data, k); - - while (!ret.empty()) { - result.push_back(ret.top()); - ret.pop(); - } - - std::sort(result.begin(), result.end(), comp); - - return result; - } - void saveIndex(const std::string &location) { std::ofstream output(location, std::ios::binary); std::streampos position; diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 06ca7993..a2f72dc7 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -9,7 +9,6 @@ #include #include - namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; @@ -1156,24 +1155,6 @@ namespace hnswlib { return result; }; - template - std::vector> - searchKnn(const void* query_data, size_t k, Comp comp) { - std::vector> result; - if (cur_element_count == 0) return result; - - auto ret = searchKnn(query_data, k); - - while (!ret.empty()) { - result.push_back(ret.top()); - ret.pop(); - } - - std::sort(result.begin(), result.end(), comp); - - return result; - } - void checkIntegrity(){ int connections_checked=0; std::vector inbound_connections_num(cur_element_count,0); diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index c26f80b5..9409c388 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -71,14 +71,34 @@ 
namespace hnswlib { public: virtual void addPoint(const void *datapoint, labeltype label)=0; virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; - template - std::vector> searchKnn(const void*, size_t, Comp) { - } + + // Return k nearest neighbors in the order of closer first + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k) const; + virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ } }; + template + std::vector> + AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k) const { + std::vector> result; + + // here searchKnn returns the result in the order of further first + auto ret = searchKnn(query_data, k); + { + size_t sz = ret.size(); + result.resize(sz); + while (!ret.empty()) { + result[--sz] = ret.top(); + ret.pop(); + } + } + + return result; + } }