9
9
#include < assert.h>
10
10
11
11
namespace py = pybind11;
12
+ using namespace pybind11 ::literals; // needed to bring in _a literal
12
13
13
14
/*
14
15
* replacement for the openmp '#pragma omp parallel for' directive
@@ -73,6 +74,12 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
73
74
74
75
}
75
76
77
+ inline void assert_true (bool expr, const std::string & msg) {
78
+ if (expr == false )
79
+ throw std::runtime_error (" Unpickle Error: " +msg);
80
+ return ;
81
+ }
82
+
76
83
77
84
78
85
template <typename dist_t , typename data_t =float >
@@ -98,7 +105,7 @@ class Index {
98
105
99
106
default_ef=10 ;
100
107
}
101
-
108
+
102
109
static const int ser_version = 1 ; // serialization version
103
110
104
111
std::string space_name;
@@ -278,15 +285,11 @@ class Index {
278
285
return ids;
279
286
}
280
287
281
- inline void assert_true (bool expr, const std::string & msg) {
282
- if (expr == false )
283
- throw std::runtime_error (" assert failed: " +msg);
284
- return ;
285
- }
288
+
289
+ py::dict getAnnData () const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
290
+
286
291
287
292
288
- py::tuple getAnnData () const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
289
-
290
293
std::unique_lock <std::mutex> templock (appr_alg->global );
291
294
292
295
unsigned int level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_ ;
@@ -345,140 +348,153 @@ class Index {
345
348
delete[] f;
346
349
});
347
350
348
- return py::make_tuple (appr_alg->offsetLevel0_ ,
349
- appr_alg->max_elements_ ,
350
- appr_alg->cur_element_count ,
351
- appr_alg->size_data_per_element_ ,
352
- appr_alg->label_offset_ ,
353
- appr_alg->offsetData_ ,
354
- appr_alg->maxlevel_ ,
355
- appr_alg->enterpoint_node_ ,
356
- appr_alg->maxM_ ,
357
- appr_alg->maxM0_ ,
358
- appr_alg->M_ ,
359
- appr_alg->mult_ ,
360
- appr_alg->ef_construction_ ,
361
- appr_alg->ef_ ,
362
- appr_alg->has_deletions_ ,
363
- appr_alg->size_links_per_element_ ,
364
- py::array_t <hnswlib::labeltype>(
365
- {appr_alg->label_lookup_ .size ()}, // shape
366
- {sizeof (hnswlib::labeltype)}, // C-style contiguous strides for double
367
- label_lookup_key_npy, // the data pointer
368
- free_when_done_lb),
369
- py::array_t <hnswlib::tableint>(
370
- {appr_alg->label_lookup_ .size ()}, // shape
371
- {sizeof (hnswlib::tableint)}, // C-style contiguous strides for double
372
- label_lookup_val_npy, // the data pointer
373
- free_when_done_id),
374
- py::array_t <int >(
375
- {appr_alg->element_levels_ .size ()}, // shape
376
- {sizeof (int )}, // C-style contiguous strides for double
377
- element_levels_npy, // the data pointer
378
- free_when_done_lvl),
379
- py::array_t <char >(
380
- {level0_npy_size}, // shape
381
- {sizeof (char )}, // C-style contiguous strides for double
382
- data_level0_npy, // the data pointer
383
- free_when_done_l0),
384
- py::array_t <char >(
385
- {link_npy_size}, // shape
386
- {sizeof (char )}, // C-style contiguous strides for double
387
- link_list_npy, // the data pointer
388
- free_when_done_ll)
389
- );
351
+ /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */
352
+ /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */
353
+
354
+ return py::dict (
355
+ " offset_level0" _a=appr_alg->offsetLevel0_ ,
356
+ " max_elements" _a=appr_alg->max_elements_ ,
357
+ " cur_element_count" _a=appr_alg->cur_element_count ,
358
+ " size_data_per_element" _a=appr_alg->size_data_per_element_ ,
359
+ " label_offset" _a=appr_alg->label_offset_ ,
360
+ " offset_data" _a=appr_alg->offsetData_ ,
361
+ " max_level" _a=appr_alg->maxlevel_ ,
362
+ " enterpoint_node" _a=appr_alg->enterpoint_node_ ,
363
+ " max_M" _a=appr_alg->maxM_ ,
364
+ " max_M0" _a=appr_alg->maxM0_ ,
365
+ " M" _a=appr_alg->M_ ,
366
+ " mult" _a=appr_alg->mult_ ,
367
+ " ef_construction" _a=appr_alg->ef_construction_ ,
368
+ " ef" _a=appr_alg->ef_ ,
369
+ " has_deletions" _a=appr_alg->has_deletions_ ,
370
+ " size_links_per_element" _a=appr_alg->size_links_per_element_ ,
371
+
372
+ " label_lookup_external" _a=py::array_t <hnswlib::labeltype>(
373
+ {appr_alg->label_lookup_ .size ()}, // shape
374
+ {sizeof (hnswlib::labeltype)}, // C-style contiguous strides for double
375
+ label_lookup_key_npy, // the data pointer
376
+ free_when_done_lb),
377
+
378
+ " label_lookup_internal" _a=py::array_t <hnswlib::tableint>(
379
+ {appr_alg->label_lookup_ .size ()}, // shape
380
+ {sizeof (hnswlib::tableint)}, // C-style contiguous strides for double
381
+ label_lookup_val_npy, // the data pointer
382
+ free_when_done_id),
383
+
384
+ " element_levels" _a=py::array_t <int >(
385
+ {appr_alg->element_levels_ .size ()}, // shape
386
+ {sizeof (int )}, // C-style contiguous strides for double
387
+ element_levels_npy, // the data pointer
388
+ free_when_done_lvl),
389
+
390
+ // linkLists_,element_levels_,data_level0_memory_
391
+ " data_level0" _a=py::array_t <char >(
392
+ {level0_npy_size}, // shape
393
+ {sizeof (char )}, // C-style contiguous strides for double
394
+ data_level0_npy, // the data pointer
395
+ free_when_done_l0),
396
+
397
+ " link_lists" _a=py::array_t <char >(
398
+ {link_npy_size}, // shape
399
+ {sizeof (char )}, // C-style contiguous strides for double
400
+ link_list_npy, // the data pointer
401
+ free_when_done_ll)
402
+
403
+ );
404
+
390
405
391
406
}
392
407
393
408
394
- py::tuple getIndexParams () const {
395
- /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */
396
- /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */
397
-
398
- return py::make_tuple (py::int_ (Index<float >::ser_version), // serialization version
399
-
400
- /* TODO: convert the following two py::tuple's to py::dict */
401
- py::make_tuple (space_name, dim, index_inited, ep_added, normalize, num_threads_default, seed, default_ef),
402
- index_inited == true ? getAnnData () : py::make_tuple ()); /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
403
-
404
-
409
+ py::dict getIndexParams () const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
410
+ auto params = py::dict (
411
+ " ser_version" _a=py::int_ (Index<float >::ser_version), // serialization version
412
+ " space" _a=space_name,
413
+ " dim" _a=dim,
414
+ " index_inited" _a=index_inited,
415
+ " ep_added" _a=ep_added,
416
+ " normalize" _a=normalize,
417
+ " num_threads" _a=num_threads_default,
418
+ " seed" _a=seed
419
+ );
420
+
421
+ if (index_inited == false )
422
+ return py::dict ( **params, " ef" _a=default_ef);
405
423
424
+ auto ann_params = getAnnData ();
425
+
426
+ return py::dict (**params, **ann_params);
406
427
}
407
428
408
429
409
- static Index<float > * createFromParams (const py::tuple t) {
410
-
411
- if (py::int_ (Index<float >::ser_version) != t[0 ].cast <int >()) // check serialization version
412
- throw std::runtime_error (" Serialization version mismatch!" );
430
+ static Index<float > * createFromParams (const py::dict d) {
413
431
414
- py::tuple index_params=t[ 1 ]. cast <py::tuple>(); /* TODO: convert index_params from py::tuple to py::dict */
415
- py::tuple ann_params=t[ 2 ].cast <py::tuple >(); /* TODO: convert ann_params from py::tuple to py::dict */
432
+ // check serialization version
433
+ assert_true ((( int ) py::int_ (Index< float >::ser_version)) >= d[ " ser_version " ].cast <int >(), " Invalid serialization version! " );
416
434
417
- auto space_name_=index_params[ 0 ].cast <std::string>();
418
- auto dim_=index_params[ 1 ].cast <int >();
419
- auto index_inited_=index_params[ 2 ].cast <bool >();
435
+ auto space_name_=d[ " space " ].cast <std::string>();
436
+ auto dim_=d[ " dim " ].cast <int >();
437
+ auto index_inited_=d[ " index_inited " ].cast <bool >();
420
438
421
- Index<float > *new_index = new Index<float >(index_params[ 0 ]. cast <std::string>(), index_params[ 1 ]. cast < int >() );
439
+ Index<float > *new_index = new Index<float >(space_name_, dim_ );
422
440
423
441
/* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */
424
442
/* for full reproducibility / state of generators is serialized inside Index::getIndexParams */
425
- new_index->seed = index_params[ 6 ].cast <size_t >();
443
+ new_index->seed = d[ " seed " ].cast <size_t >();
426
444
427
445
if (index_inited_){
428
- new_index->appr_alg = new hnswlib::HierarchicalNSW<dist_t >(new_index->l2space , ann_params[ 1 ].cast <size_t >(), ann_params[ 10 ].cast <size_t >(), ann_params[ 12 ].cast <size_t >(), new_index->seed );
429
- new_index->cur_l = ann_params[ 2 ].cast <size_t >();
446
+ new_index->appr_alg = new hnswlib::HierarchicalNSW<dist_t >(new_index->l2space , d[ " max_elements " ].cast <size_t >(), d[ " M " ].cast <size_t >(), d[ " ef_construction " ].cast <size_t >(), new_index->seed );
447
+ new_index->cur_l = d[ " cur_element_count " ].cast <size_t >();
430
448
}
431
449
432
450
new_index->index_inited = index_inited_;
433
- new_index->ep_added =index_params[ 3 ].cast <bool >();
434
- new_index->num_threads_default =index_params[ 5 ].cast <int >();
435
- new_index->default_ef =index_params[ 7 ].cast <size_t >();
451
+ new_index->ep_added =d[ " ep_added " ].cast <bool >();
452
+ new_index->num_threads_default =d[ " num_threads " ].cast <int >();
453
+ new_index->default_ef =d[ " ef " ].cast <size_t >();
436
454
437
455
if (index_inited_)
438
- new_index->setAnnData (ann_params);
439
-
456
+ new_index->setAnnData (d);
440
457
441
458
return new_index;
442
459
}
443
460
444
461
static Index<float > * createFromIndex (const Index<float > & index) {
445
- /* WARNING: Index::getIndexParams is not thread-safe with Index::addItems */
446
- return createFromParams (index .getIndexParams ());
462
+ return createFromParams (index .getIndexParams ());
447
463
}
448
464
449
-
450
- void setAnnData (const py::tuple t) {
451
- /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */
452
-
465
+ void setAnnData (const py::dict d) { /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */
466
+
467
+
453
468
std::unique_lock <std::mutex> templock (appr_alg->global );
454
469
455
- assert_true (appr_alg->offsetLevel0_ == t[0 ].cast <size_t >(), " Invalid value of offsetLevel0_ " );
456
- assert_true (appr_alg->max_elements_ == t[1 ].cast <size_t >(), " Invalid value of max_elements_ " );
470
+ assert_true (appr_alg->offsetLevel0_ == d[" offset_level0" ].cast <size_t >(), " Invalid value of offsetLevel0_ " );
471
+ assert_true (appr_alg->max_elements_ == d[" max_elements" ].cast <size_t >(), " Invalid value of max_elements_ " );
472
+
473
+ appr_alg->cur_element_count = d[" cur_element_count" ].cast <size_t >();
457
474
458
- appr_alg->cur_element_count = t[2 ].cast <size_t >();
475
+ assert_true (appr_alg->size_data_per_element_ == d[" size_data_per_element" ].cast <size_t >(), " Invalid value of size_data_per_element_ " );
476
+ assert_true (appr_alg->label_offset_ == d[" label_offset" ].cast <size_t >(), " Invalid value of label_offset_ " );
477
+ assert_true (appr_alg->offsetData_ == d[" offset_data" ].cast <size_t >(), " Invalid value of offsetData_ " );
459
478
460
- assert_true (appr_alg->size_data_per_element_ == t[3 ].cast <size_t >(), " Invalid value of size_data_per_element_ " );
461
- assert_true (appr_alg->label_offset_ == t[4 ].cast <size_t >(), " Invalid value of label_offset_ " );
462
- assert_true (appr_alg->offsetData_ == t[5 ].cast <size_t >(), " Invalid value of offsetData_ " );
479
+ appr_alg->maxlevel_ = d[" max_level" ].cast <int >();
480
+ appr_alg->enterpoint_node_ = d[" enterpoint_node" ].cast <hnswlib::tableint>();
463
481
464
- appr_alg->maxlevel_ = t[6 ].cast <int >();
465
- appr_alg->enterpoint_node_ = t[7 ].cast <hnswlib::tableint>();
482
+ assert_true (appr_alg->maxM_ == d[" max_M" ].cast <size_t >(), " Invalid value of maxM_ " );
483
+ assert_true (appr_alg->maxM0_ == d[" max_M0" ].cast <size_t >(), " Invalid value of maxM0_ " );
484
+ assert_true (appr_alg->M_ == d[" M" ].cast <size_t >(), " Invalid value of M_ " );
485
+ assert_true (appr_alg->mult_ == d[" mult" ].cast <double >(), " Invalid value of mult_ " );
486
+ assert_true (appr_alg->ef_construction_ == d[" ef_construction" ].cast <size_t >(), " Invalid value of ef_construction_ " );
466
487
467
- assert_true (appr_alg->maxM_ == t[8 ].cast <size_t >(), " Invalid value of maxM_ " );
468
- assert_true (appr_alg->maxM0_ == t[9 ].cast <size_t >(), " Invalid value of maxM0_ " );
469
- assert_true (appr_alg->M_ == t[10 ].cast <size_t >(), " Invalid value of M_ " );
470
- assert_true (appr_alg->mult_ == t[11 ].cast <double >(), " Invalid value of mult_ " );
471
- assert_true (appr_alg->ef_construction_ == t[12 ].cast <size_t >(), " Invalid value of ef_construction_ " );
488
+ appr_alg->ef_ = d[" ef" ].cast <size_t >();
489
+ appr_alg->has_deletions_ =d[" has_deletions" ].cast <bool >();
472
490
473
- appr_alg->ef_ = t[13 ].cast <size_t >();
474
- appr_alg->has_deletions_ =t[14 ].cast <bool >();
475
- assert_true (appr_alg->size_links_per_element_ == t[15 ].cast <size_t >(), " Invalid value of size_links_per_element_ " );
491
+ assert_true (appr_alg->size_links_per_element_ == d[" size_links_per_element" ].cast <size_t >(), " Invalid value of size_links_per_element_ " );
476
492
477
- auto label_lookup_key_npy = t[ 16 ].cast <py::array_t < hnswlib::labeltype, py::array::c_style | py::array::forcecast > >();
478
- auto label_lookup_val_npy = t[ 17 ].cast <py::array_t < hnswlib::tableint, py::array::c_style | py::array::forcecast > >();
479
- auto element_levels_npy = t[ 18 ].cast <py::array_t < int , py::array::c_style | py::array::forcecast > >();
480
- auto data_level0_npy = t[ 19 ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
481
- auto link_list_npy = t[ 20 ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
493
+ auto label_lookup_key_npy = d[ " label_lookup_external " ].cast <py::array_t < hnswlib::labeltype, py::array::c_style | py::array::forcecast > >();
494
+ auto label_lookup_val_npy = d[ " label_lookup_internal " ].cast <py::array_t < hnswlib::tableint, py::array::c_style | py::array::forcecast > >();
495
+ auto element_levels_npy = d[ " element_levels " ].cast <py::array_t < int , py::array::c_style | py::array::forcecast > >();
496
+ auto data_level0_npy = d[ " data_level0 " ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
497
+ auto link_list_npy = d[ " link_lists " ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
482
498
483
499
for (size_t i = 0 ; i < appr_alg->cur_element_count ; i++){
484
500
if (label_lookup_val_npy.data ()[i] < 0 ){
@@ -516,7 +532,6 @@ class Index {
516
532
517
533
}
518
534
}
519
-
520
535
}
521
536
522
537
py::object knnQuery_return_numpy (py::object input, size_t k = 1 , int num_threads = -1 ) {
@@ -640,9 +655,9 @@ PYBIND11_PLUGIN(hnswlib) {
640
655
py::module m (" hnswlib" );
641
656
642
657
py::class_<Index<float >>(m, " Index" )
643
- .def (py::init (&Index<float >::createFromParams), py::arg (" params" ))
658
+ .def (py::init (&Index<float >::createFromParams), py::arg (" params" ))
644
659
/* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */
645
- .def (py::init (&Index<float >::createFromIndex), py::arg (" index" ))
660
+ .def (py::init (&Index<float >::createFromIndex), py::arg (" index" ))
646
661
.def (py::init<const std::string &, const int >(), py::arg (" space" ), py::arg (" dim" ))
647
662
.def (" init_index" , &Index<float >::init_new_index, py::arg (" max_elements" ), py::arg (" M" )=16 , py::arg (" ef_construction" )=200 , py::arg (" random_seed" )=100 )
648
663
.def (" knn_query" , &Index<float >::knnQuery_return_numpy, py::arg (" data" ), py::arg (" k" )=1 , py::arg (" num_threads" )=-1 )
@@ -682,14 +697,13 @@ PYBIND11_PLUGIN(hnswlib) {
682
697
683
698
.def (py::pickle (
684
699
[](const Index<float > &ind) { // __getstate__
685
- /* Return a tuple that fully encodes the state of the object */
686
- /* WARNING: Index::getIndexParams is not thread-safe with Index::addItems */
687
- return ind.getIndexParams ();
700
+ return py::make_tuple (ind.getIndexParams ()); /* Return dict (wrapped in a tuple) that fully encodes state of the Index object */
688
701
},
689
702
[](py::tuple t) { // __setstate__
690
- if (t.size () != 3 )
703
+ if (t.size () != 1 )
691
704
throw std::runtime_error (" Invalid state!" );
692
- return Index<float >::createFromParams (t);
705
+
706
+ return Index<float >::createFromParams (t[0 ].cast <py::dict>());
693
707
}
694
708
))
695
709
0 commit comments