@@ -719,6 +719,7 @@ class BFIndex {
719
719
int dim;
720
720
bool index_inited;
721
721
bool normalize;
722
+ int num_threads_default;
722
723
723
724
hnswlib::labeltype cur_l;
724
725
hnswlib::BruteforceSearch<dist_t >* alg;
@@ -739,6 +740,8 @@ class BFIndex {
739
740
}
740
741
alg = NULL ;
741
742
index_inited = false ;
743
+
744
+ num_threads_default = std::thread::hardware_concurrency ();
742
745
}
743
746
744
747
@@ -749,6 +752,21 @@ class BFIndex {
749
752
}
750
753
751
754
755
+ size_t getMaxElements () const {
756
+ return alg->maxelements_ ;
757
+ }
758
+
759
+
760
+ size_t getCurrentCount () const {
761
+ return alg->cur_element_count ;
762
+ }
763
+
764
+
765
+ void set_num_threads (int num_threads) {
766
+ this ->num_threads_default = num_threads;
767
+ }
768
+
769
+
752
770
void init_new_index (const size_t maxElements) {
753
771
if (alg) {
754
772
throw std::runtime_error (" The index is already initiated." );
@@ -820,15 +838,19 @@ class BFIndex {
820
838
py::object knnQuery_return_numpy (
821
839
py::object input,
822
840
size_t k = 1 ,
841
+ int num_threads = -1 ,
823
842
const std::function<bool (hnswlib::labeltype)>& filter = nullptr) {
824
843
py::array_t < dist_t , py::array::c_style | py::array::forcecast > items (input);
825
844
auto buffer = items.request ();
826
845
hnswlib::labeltype *data_numpy_l;
827
846
dist_t *data_numpy_d;
828
847
size_t rows, features;
848
+
849
+ if (num_threads <= 0 )
850
+ num_threads = num_threads_default;
851
+
829
852
{
830
853
py::gil_scoped_release l;
831
-
832
854
get_input_array_shapes (buffer, &rows, &features);
833
855
834
856
data_numpy_l = new hnswlib::labeltype[rows * k];
@@ -837,16 +859,16 @@ class BFIndex {
837
859
CustomFilterFunctor idFilter (filter);
838
860
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr ;
839
861
840
- for ( size_t row = 0 ; row < rows; row++ ) {
862
+ ParallelFor ( 0 , rows, num_threads, [&]( size_t row, size_t threadId ) {
841
863
std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
842
- (void *) items.data (row), k, p_idFilter);
864
+ (void *) items.data (row), k, p_idFilter);
843
865
for (int i = k - 1 ; i >= 0 ; i--) {
844
- auto & result_tuple = result.top ();
866
+ auto & result_tuple = result.top ();
845
867
data_numpy_d[row * k + i] = result_tuple.first ;
846
868
data_numpy_l[row * k + i] = result_tuple.second ;
847
869
result.pop ();
848
870
}
849
- }
871
+ });
850
872
}
851
873
852
874
py::capsule free_when_done_l (data_numpy_l, [](void *f) {
@@ -957,13 +979,22 @@ PYBIND11_PLUGIN(hnswlib) {
957
979
py::class_<BFIndex<float >>(m, " BFIndex" )
958
980
.def (py::init<const std::string &, const int >(), py::arg (" space" ), py::arg (" dim" ))
959
981
.def (" init_index" , &BFIndex<float >::init_new_index, py::arg (" max_elements" ))
960
- .def (" knn_query" , &BFIndex<float >::knnQuery_return_numpy, py::arg (" data" ), py::arg (" k" ) = 1 , py::arg (" filter" ) = py::none ())
982
+ .def (" knn_query" ,
983
+ &BFIndex<float >::knnQuery_return_numpy,
984
+ py::arg (" data" ),
985
+ py::arg (" k" ) = 1 ,
986
+ py::arg (" num_threads" ) = -1 ,
987
+ py::arg (" filter" ) = py::none ())
961
988
.def (" add_items" , &BFIndex<float >::addItems, py::arg (" data" ), py::arg (" ids" ) = py::none ())
962
989
.def (" delete_vector" , &BFIndex<float >::deleteVector, py::arg (" label" ))
990
+ .def (" set_num_threads" , &BFIndex<float >::set_num_threads, py::arg (" num_threads" ))
963
991
.def (" save_index" , &BFIndex<float >::saveIndex, py::arg (" path_to_index" ))
964
992
.def (" load_index" , &BFIndex<float >::loadIndex, py::arg (" path_to_index" ), py::arg (" max_elements" ) = 0 )
965
993
.def (" __repr__" , [](const BFIndex<float > &a) {
966
994
return " <hnswlib.BFIndex(space='" + a.space_name + " ', dim=" +std::to_string (a.dim )+" )>" ;
967
- });
995
+ })
996
+ .def (" get_max_elements" , &BFIndex<float >::getMaxElements)
997
+ .def (" get_current_count" , &BFIndex<float >::getCurrentCount)
998
+ .def_readwrite (" num_threads" , &BFIndex<float >::num_threads_default);
968
999
return m.ptr ();
969
1000
}
0 commit comments