Skip to content

Remove template interface searchKnn #225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ endif()

add_executable(test_updates examples/updates_test.cpp)

add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp)

target_link_libraries(main sift_test)
84 changes: 84 additions & 0 deletions examples/searchKnnCloserFirst_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// This is a test file for testing the interface
// >>> virtual std::vector<std::pair<dist_t, labeltype>>
// >>> searchKnnCloserFirst(const void* query_data, size_t k) const;
// of class AlgorithmInterface

#include "../hnswlib/hnswlib.h"

#include <assert.h>

#include <cstdlib>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

namespace
{

using idx_t = hnswlib::labeltype;

void test() {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);

std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; ++i) {
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i) {
query[i] = distrib(rng);
}


hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
alg_brute->addPoint(data.data() + d * i, i);
alg_hnsw->addPoint(data.data() + d * i, i);
}

// test searchKnnCloserFirst of BruteforceSearch
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k);
auto res = alg_brute->searchKnnCloserFirst(p, k);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
gd.pop();
}
}
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k);
auto res = alg_hnsw->searchKnnCloserFirst(p, k);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
gd.pop();
}
}

delete alg_brute;
delete alg_hnsw;
}

} // namespace

// Entry point: announces the run, executes test(), and reports success.
// std::endl is kept so the opening banner is flushed before test() runs.
int main() {
    const char* const banner = "Testing ...";
    const char* const verdict = "Test ok";

    std::cout << banner << std::endl;
    test();
    std::cout << verdict << std::endl;

    return 0;
}
18 changes: 0 additions & 18 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,6 @@ namespace hnswlib {
return topResults;
};

template <typename Comp>
std::vector<std::pair<dist_t, labeltype>>
searchKnn(const void* query_data, size_t k, Comp comp) {
std::vector<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0) return result;

auto ret = searchKnn(query_data, k);

while (!ret.empty()) {
result.push_back(ret.top());
ret.pop();
}

std::sort(result.begin(), result.end(), comp);

return result;
}

void saveIndex(const std::string &location) {
std::ofstream output(location, std::ios::binary);
std::streampos position;
Expand Down
19 changes: 0 additions & 19 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <unordered_set>
#include <list>


namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;
Expand Down Expand Up @@ -1156,24 +1155,6 @@ namespace hnswlib {
return result;
};

template <typename Comp>
std::vector<std::pair<dist_t, labeltype>>
searchKnn(const void* query_data, size_t k, Comp comp) {
std::vector<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0) return result;

auto ret = searchKnn(query_data, k);

while (!ret.empty()) {
result.push_back(ret.top());
ret.pop();
}

std::sort(result.begin(), result.end(), comp);

return result;
}

void checkIntegrity(){
int connections_checked=0;
std::vector <int > inbound_connections_num(cur_element_count,0);
Expand Down
26 changes: 23 additions & 3 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,34 @@ namespace hnswlib {
public:
// Adds one element to the index: `datapoint` points at the element's raw data
// and `label` is the identifier stored with it. The data layout is defined by
// the concrete index (presumably a vector of the space's dimension — confirm
// against the implementing class).
virtual void addPoint(const void *datapoint, labeltype label)=0;
// Searches the index and returns (distance, label) pairs in a
// std::priority_queue with the default comparator, so top() is the pair with
// the largest distance (farther first). The unnamed parameters are the query
// data pointer and the number of neighbours requested, matching
// searchKnnCloserFirst(query_data, k).
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
template <typename Comp>
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
}

// Returns the k nearest neighbors of `query_data` as (distance, label) pairs,
// ordered closer first (smallest distance at index 0).
// Virtual with a default implementation, defined out of class further down in
// this header, that reverses the pop order of searchKnn().
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k) const;

// Persists the index to `location`. The on-disk format is defined by the
// concrete index; BruteforceSearch opens `location` as a binary std::ofstream.
virtual void saveIndex(const std::string &location)=0;
// Virtual so that concrete indices can be destroyed through an
// AlgorithmInterface pointer; the interface itself owns nothing.
virtual ~AlgorithmInterface() = default;
};

// Default implementation of searchKnnCloserFirst(): runs searchKnn() and
// returns its results as a vector with the closest neighbor at index 0.
template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
    // searchKnn() hands back a heap whose top() is the farthest remaining neighbor.
    auto farthestFirst = searchKnn(query_data, k);

    // Write from the last slot towards the first so the farthest neighbor
    // lands at the end and the closest at the front.
    std::vector<std::pair<dist_t, labeltype>> closestFirst(farthestFirst.size());
    for (auto slot = closestFirst.rbegin(); slot != closestFirst.rend(); ++slot) {
        *slot = farthestFirst.top();
        farthestFirst.pop();
    }

    return closestFirst;
}

}

Expand Down