Skip to content

Commit 334cc6c

Browse files
authored
Merge pull request #258 from dbespalov/python_bindings_state_dict
Use dict for Index serialization
2 parents 8cc442d + 4c002bc commit 334cc6c

File tree

1 file changed

+128
-114
lines changed

1 file changed

+128
-114
lines changed

python_bindings/bindings.cpp

Lines changed: 128 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <assert.h>
1010

1111
namespace py = pybind11;
12+
using namespace pybind11::literals; // needed to bring in _a literal
1213

1314
/*
1415
* replacement for the openmp '#pragma omp parallel for' directive
@@ -73,6 +74,12 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
7374

7475
}
7576

77+
/*
 * Guard used while restoring an Index from serialized state.
 * Throws std::runtime_error with the text "Unpickle Error: " followed by msg
 * when expr is false; returns normally when expr is true.
 */
inline void assert_true(bool expr, const std::string & msg) {
    if (!expr) {
        throw std::runtime_error("Unpickle Error: " + msg);
    }
}
82+
7683

7784

7885
template<typename dist_t, typename data_t=float>
@@ -98,7 +105,7 @@ class Index {
98105

99106
default_ef=10;
100107
}
101-
108+
102109
static const int ser_version = 1; // serialization version
103110

104111
std::string space_name;
@@ -278,15 +285,11 @@ class Index {
278285
return ids;
279286
}
280287

281-
inline void assert_true(bool expr, const std::string & msg) {
282-
if (expr == false)
283-
throw std::runtime_error("assert failed: "+msg);
284-
return;
285-
}
288+
289+
py::dict getAnnData() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
290+
286291

287292

288-
py::tuple getAnnData() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
289-
290293
std::unique_lock <std::mutex> templock(appr_alg->global);
291294

292295
unsigned int level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_;
@@ -345,140 +348,153 @@ class Index {
345348
delete[] f;
346349
});
347350

348-
return py::make_tuple(appr_alg->offsetLevel0_,
349-
appr_alg->max_elements_,
350-
appr_alg->cur_element_count,
351-
appr_alg->size_data_per_element_,
352-
appr_alg->label_offset_,
353-
appr_alg->offsetData_,
354-
appr_alg->maxlevel_,
355-
appr_alg->enterpoint_node_,
356-
appr_alg->maxM_,
357-
appr_alg->maxM0_,
358-
appr_alg->M_,
359-
appr_alg->mult_,
360-
appr_alg->ef_construction_,
361-
appr_alg->ef_,
362-
appr_alg->has_deletions_,
363-
appr_alg->size_links_per_element_,
364-
py::array_t<hnswlib::labeltype>(
365-
{appr_alg->label_lookup_.size()}, // shape
366-
{sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double
367-
label_lookup_key_npy, // the data pointer
368-
free_when_done_lb),
369-
py::array_t<hnswlib::tableint>(
370-
{appr_alg->label_lookup_.size()}, // shape
371-
{sizeof(hnswlib::tableint)}, // C-style contiguous strides for double
372-
label_lookup_val_npy, // the data pointer
373-
free_when_done_id),
374-
py::array_t<int>(
375-
{appr_alg->element_levels_.size()}, // shape
376-
{sizeof(int)}, // C-style contiguous strides for double
377-
element_levels_npy, // the data pointer
378-
free_when_done_lvl),
379-
py::array_t<char>(
380-
{level0_npy_size}, // shape
381-
{sizeof(char)}, // C-style contiguous strides for double
382-
data_level0_npy, // the data pointer
383-
free_when_done_l0),
384-
py::array_t<char>(
385-
{link_npy_size}, // shape
386-
{sizeof(char)}, // C-style contiguous strides for double
387-
link_list_npy, // the data pointer
388-
free_when_done_ll)
389-
);
351+
/* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */
352+
/* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */
353+
354+
return py::dict(
355+
"offset_level0"_a=appr_alg->offsetLevel0_,
356+
"max_elements"_a=appr_alg->max_elements_,
357+
"cur_element_count"_a=appr_alg->cur_element_count,
358+
"size_data_per_element"_a=appr_alg->size_data_per_element_,
359+
"label_offset"_a=appr_alg->label_offset_,
360+
"offset_data"_a=appr_alg->offsetData_,
361+
"max_level"_a=appr_alg->maxlevel_,
362+
"enterpoint_node"_a=appr_alg->enterpoint_node_,
363+
"max_M"_a=appr_alg->maxM_,
364+
"max_M0"_a=appr_alg->maxM0_,
365+
"M"_a=appr_alg->M_,
366+
"mult"_a=appr_alg->mult_,
367+
"ef_construction"_a=appr_alg->ef_construction_,
368+
"ef"_a=appr_alg->ef_,
369+
"has_deletions"_a=appr_alg->has_deletions_,
370+
"size_links_per_element"_a=appr_alg->size_links_per_element_,
371+
372+
"label_lookup_external"_a=py::array_t<hnswlib::labeltype>(
373+
{appr_alg->label_lookup_.size()}, // shape
374+
{sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double
375+
label_lookup_key_npy, // the data pointer
376+
free_when_done_lb),
377+
378+
"label_lookup_internal"_a=py::array_t<hnswlib::tableint>(
379+
{appr_alg->label_lookup_.size()}, // shape
380+
{sizeof(hnswlib::tableint)}, // C-style contiguous strides for double
381+
label_lookup_val_npy, // the data pointer
382+
free_when_done_id),
383+
384+
"element_levels"_a=py::array_t<int>(
385+
{appr_alg->element_levels_.size()}, // shape
386+
{sizeof(int)}, // C-style contiguous strides for double
387+
element_levels_npy, // the data pointer
388+
free_when_done_lvl),
389+
390+
// linkLists_,element_levels_,data_level0_memory_
391+
"data_level0"_a=py::array_t<char>(
392+
{level0_npy_size}, // shape
393+
{sizeof(char)}, // C-style contiguous strides for double
394+
data_level0_npy, // the data pointer
395+
free_when_done_l0),
396+
397+
"link_lists"_a=py::array_t<char>(
398+
{link_npy_size}, // shape
399+
{sizeof(char)}, // C-style contiguous strides for double
400+
link_list_npy, // the data pointer
401+
free_when_done_ll)
402+
403+
);
404+
390405

391406
}
392407

393408

394-
py::tuple getIndexParams() const {
395-
/* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */
396-
/* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */
397-
398-
return py::make_tuple(py::int_(Index<float>::ser_version), // serialization version
399-
400-
/* TODO: convert the following two py::tuple's to py::dict */
401-
py::make_tuple(space_name, dim, index_inited, ep_added, normalize, num_threads_default, seed, default_ef),
402-
index_inited == true ? getAnnData() : py::make_tuple()); /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
403-
404-
409+
/*
 * Packs the state of this Index into one flat py::dict.
 *
 * The dict always carries ser_version, space, dim, index_inited, ep_added,
 * normalize, num_threads and seed. When the index has not been initialised,
 * "ef" holds default_ef; otherwise every entry returned by getAnnData() is
 * merged in (that dict supplies its own "ef" entry).
 *
 * WARNING: Index::getAnnData is not thread-safe with Index::addItems
 */
py::dict getIndexParams() const {
    py::dict index_params(
        "ser_version"_a=py::int_(Index<float>::ser_version), // serialization version
        "space"_a=space_name,
        "dim"_a=dim,
        "index_inited"_a=index_inited,
        "ep_added"_a=ep_added,
        "normalize"_a=normalize,
        "num_threads"_a=num_threads_default,
        "seed"_a=seed
    );

    if (index_inited) {
        // initialised index: append the ANN structure fields to the common ones
        py::dict ann_params = getAnnData();
        return py::dict(**index_params, **ann_params);
    }

    // no ANN data yet: only the default ef needs to be recorded
    return py::dict(**index_params, "ef"_a=default_ef);
}
407428

408429

409-
static Index<float> * createFromParams(const py::tuple t) {
410-
411-
if (py::int_(Index<float>::ser_version) != t[0].cast<int>()) // check serialization version
412-
throw std::runtime_error("Serialization version mismatch!");
430+
/*
 * Builds a heap-allocated Index<float> from a state dict with the keys
 * written by getIndexParams().
 *
 * Throws std::runtime_error ("Unpickle Error: ...") via assert_true when the
 * stored "ser_version" is greater than Index<float>::ser_version.
 * The caller receives the raw pointer created with new.
 */
static Index<float> * createFromParams(const py::dict d) {
    // check serialization version: state from a newer version is rejected
    const int stored_version = d["ser_version"].cast<int>();
    assert_true(static_cast<int>(py::int_(Index<float>::ser_version)) >= stored_version, "Invalid serialization version!");

    const std::string restored_space = d["space"].cast<std::string>();
    const int restored_dim = d["dim"].cast<int>();
    const bool was_inited = d["index_inited"].cast<bool>();

    Index<float> *restored = new Index<float>(restored_space, restored_dim);

    /* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */
    /* for full reproducibility / state of generators is serialized inside Index::getIndexParams */
    restored->seed = d["seed"].cast<size_t>();

    if (was_inited) {
        // recreate the ANN structure with the stored capacity and build parameters
        const size_t stored_max_elements = d["max_elements"].cast<size_t>();
        const size_t stored_M = d["M"].cast<size_t>();
        const size_t stored_ef_construction = d["ef_construction"].cast<size_t>();
        restored->appr_alg = new hnswlib::HierarchicalNSW<dist_t>(restored->l2space, stored_max_elements, stored_M, stored_ef_construction, restored->seed);
        restored->cur_l = d["cur_element_count"].cast<size_t>();
    }

    restored->index_inited = was_inited;
    restored->ep_added = d["ep_added"].cast<bool>();
    restored->num_threads_default = d["num_threads"].cast<int>();
    restored->default_ef = d["ef"].cast<size_t>();

    if (was_inited) {
        // copy the serialized ANN fields and arrays into the fresh structure
        restored->setAnnData(d);
    }

    return restored;
}
443460

444461
/*
 * Creates a new Index<float> from the state dict of an existing one:
 * the dict returned by index.getIndexParams() is handed to createFromParams().
 *
 * WARNING: Index::getIndexParams is not thread-safe with Index::addItems
 */
static Index<float> * createFromIndex(const Index<float> & index) {
    py::dict state = index.getIndexParams();
    return createFromParams(state);
}
448464

449-
450-
void setAnnData(const py::tuple t) {
451-
/* WARNING: Index::setAnnData is not thread-safe with Index::addItems */
452-
465+
void setAnnData(const py::dict d) { /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */
466+
467+
453468
std::unique_lock <std::mutex> templock(appr_alg->global);
454469

455-
assert_true(appr_alg->offsetLevel0_ == t[0].cast<size_t>(), "Invalid value of offsetLevel0_ ");
456-
assert_true(appr_alg->max_elements_ == t[1].cast<size_t>(), "Invalid value of max_elements_ ");
470+
assert_true(appr_alg->offsetLevel0_ == d["offset_level0"].cast<size_t>(), "Invalid value of offsetLevel0_ ");
471+
assert_true(appr_alg->max_elements_ == d["max_elements"].cast<size_t>(), "Invalid value of max_elements_ ");
472+
473+
appr_alg->cur_element_count = d["cur_element_count"].cast<size_t>();
457474

458-
appr_alg->cur_element_count = t[2].cast<size_t>();
475+
assert_true(appr_alg->size_data_per_element_ == d["size_data_per_element"].cast<size_t>(), "Invalid value of size_data_per_element_ ");
476+
assert_true(appr_alg->label_offset_ == d["label_offset"].cast<size_t>(), "Invalid value of label_offset_ ");
477+
assert_true(appr_alg->offsetData_ == d["offset_data"].cast<size_t>(), "Invalid value of offsetData_ ");
459478

460-
assert_true(appr_alg->size_data_per_element_ == t[3].cast<size_t>(), "Invalid value of size_data_per_element_ ");
461-
assert_true(appr_alg->label_offset_ == t[4].cast<size_t>(), "Invalid value of label_offset_ ");
462-
assert_true(appr_alg->offsetData_ == t[5].cast<size_t>(), "Invalid value of offsetData_ ");
479+
appr_alg->maxlevel_ = d["max_level"].cast<int>();
480+
appr_alg->enterpoint_node_ = d["enterpoint_node"].cast<hnswlib::tableint>();
463481

464-
appr_alg->maxlevel_ = t[6].cast<int>();
465-
appr_alg->enterpoint_node_ = t[7].cast<hnswlib::tableint>();
482+
assert_true(appr_alg->maxM_ == d["max_M"].cast<size_t>(), "Invalid value of maxM_ ");
483+
assert_true(appr_alg->maxM0_ == d["max_M0"].cast<size_t>(), "Invalid value of maxM0_ ");
484+
assert_true(appr_alg->M_ == d["M"].cast<size_t>(), "Invalid value of M_ ");
485+
assert_true(appr_alg->mult_ == d["mult"].cast<double>(), "Invalid value of mult_ ");
486+
assert_true(appr_alg->ef_construction_ == d["ef_construction"].cast<size_t>(), "Invalid value of ef_construction_ ");
466487

467-
assert_true(appr_alg->maxM_ == t[8].cast<size_t>(), "Invalid value of maxM_ ");
468-
assert_true(appr_alg->maxM0_ == t[9].cast<size_t>(), "Invalid value of maxM0_ ");
469-
assert_true(appr_alg->M_ == t[10].cast<size_t>(), "Invalid value of M_ ");
470-
assert_true(appr_alg->mult_ == t[11].cast<double>(), "Invalid value of mult_ ");
471-
assert_true(appr_alg->ef_construction_ == t[12].cast<size_t>(), "Invalid value of ef_construction_ ");
488+
appr_alg->ef_ = d["ef"].cast<size_t>();
489+
appr_alg->has_deletions_=d["has_deletions"].cast<bool>();
472490

473-
appr_alg->ef_ = t[13].cast<size_t>();
474-
appr_alg->has_deletions_=t[14].cast<bool>();
475-
assert_true(appr_alg->size_links_per_element_ == t[15].cast<size_t>(), "Invalid value of size_links_per_element_ ");
491+
assert_true(appr_alg->size_links_per_element_ == d["size_links_per_element"].cast<size_t>(), "Invalid value of size_links_per_element_ ");
476492

477-
auto label_lookup_key_npy = t[16].cast<py::array_t < hnswlib::labeltype, py::array::c_style | py::array::forcecast > >();
478-
auto label_lookup_val_npy = t[17].cast<py::array_t < hnswlib::tableint, py::array::c_style | py::array::forcecast > >();
479-
auto element_levels_npy = t[18].cast<py::array_t < int, py::array::c_style | py::array::forcecast > >();
480-
auto data_level0_npy = t[19].cast<py::array_t < char, py::array::c_style | py::array::forcecast > >();
481-
auto link_list_npy = t[20].cast<py::array_t < char, py::array::c_style | py::array::forcecast > >();
493+
auto label_lookup_key_npy = d["label_lookup_external"].cast<py::array_t < hnswlib::labeltype, py::array::c_style | py::array::forcecast > >();
494+
auto label_lookup_val_npy = d["label_lookup_internal"].cast<py::array_t < hnswlib::tableint, py::array::c_style | py::array::forcecast > >();
495+
auto element_levels_npy = d["element_levels"].cast<py::array_t < int, py::array::c_style | py::array::forcecast > >();
496+
auto data_level0_npy = d["data_level0"].cast<py::array_t < char, py::array::c_style | py::array::forcecast > >();
497+
auto link_list_npy = d["link_lists"].cast<py::array_t < char, py::array::c_style | py::array::forcecast > >();
482498

483499
for (size_t i = 0; i < appr_alg->cur_element_count; i++){
484500
if (label_lookup_val_npy.data()[i] < 0){
@@ -516,7 +532,6 @@ class Index {
516532

517533
}
518534
}
519-
520535
}
521536

522537
py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {
@@ -640,9 +655,9 @@ PYBIND11_PLUGIN(hnswlib) {
640655
py::module m("hnswlib");
641656

642657
py::class_<Index<float>>(m, "Index")
643-
.def(py::init(&Index<float>::createFromParams), py::arg("params"))
658+
.def(py::init(&Index<float>::createFromParams), py::arg("params"))
644659
/* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */
645-
.def(py::init(&Index<float>::createFromIndex), py::arg("index"))
660+
.def(py::init(&Index<float>::createFromIndex), py::arg("index"))
646661
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
647662
.def("init_index", &Index<float>::init_new_index, py::arg("max_elements"), py::arg("M")=16, py::arg("ef_construction")=200, py::arg("random_seed")=100)
648663
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1, py::arg("num_threads")=-1)
@@ -682,14 +697,13 @@ PYBIND11_PLUGIN(hnswlib) {
682697

683698
.def(py::pickle(
684699
[](const Index<float> &ind) { // __getstate__
685-
/* Return a tuple that fully encodes the state of the object */
686-
/* WARNING: Index::getIndexParams is not thread-safe with Index::addItems */
687-
return ind.getIndexParams();
700+
return py::make_tuple(ind.getIndexParams()); /* Return dict (wrapped in a tuple) that fully encodes state of the Index object */
688701
},
689702
[](py::tuple t) { // __setstate__
690-
if (t.size() != 3)
703+
if (t.size() != 1)
691704
throw std::runtime_error("Invalid state!");
692-
return Index<float>::createFromParams(t);
705+
706+
return Index<float>::createFromParams(t[0].cast<py::dict>());
693707
}
694708
))
695709

0 commit comments

Comments
 (0)