From fb6c1f757208d76fd387ec046cbad6c3f978a3cc Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Fri, 25 Jul 2025 19:17:35 +0530 Subject: [PATCH] [ENH]: Added functionality to use a in-memory intermediary to load --- CMakeLists.txt | 2 +- Cargo.lock | 2 +- hnswlib/hnswalg.h | 425 ++++++++++++++++++++++++++++++---- src/bindings.cpp | 157 ++++++++++++- src/hnsw.rs | 364 ++++++++++++++++++++++++++++- tests/cpp/persistent_test.cpp | 60 +++++ 6 files changed, 964 insertions(+), 46 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c292a59a..b60eeb05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required (VERSION 2.6) +cmake_minimum_required (VERSION 3.5) project(hnsw_lib LANGUAGES CXX) diff --git a/Cargo.lock b/Cargo.lock index ecc83867..648b4f3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,7 +101,7 @@ dependencies = [ [[package]] name = "hnswlib" -version = "0.8.0" +version = "0.8.1" dependencies = [ "cc", "rand", diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e4e8150f..48a2c586 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -5,13 +5,184 @@ #include #include #include +#include +#include #include +#include +#include #include #include -#include +#include namespace hnswlib { + #define HEADER_FIELDS(ACTION) \ + ACTION(PERSISTENCE_VERSION) \ + ACTION(offsetLevel0_) \ + ACTION(max_elements_) \ + ACTION(cur_element_count) \ + ACTION(size_data_per_element_) \ + ACTION(label_offset_) \ + ACTION(offsetData_) \ + ACTION(maxlevel_) \ + ACTION(enterpoint_node_) \ + ACTION(maxM_) \ + ACTION(maxM0_) \ + ACTION(M_) \ + ACTION(mult_) \ + ACTION(ef_construction_) + + // --- Input Streambuf (read from memory) --- + class in_membuf : public std::streambuf { + public: + in_membuf(const char* data, std::size_t size) { + char* p = const_cast(data); + setg(p, p, p + size); // set get pointers + } + }; + + // --- Input Stream --- + class memistream : public std::istream { + public: + memistream(const char* data, std::size_t size) + : std::istream(&buf), buf(data, size) {} + + private: + in_membuf buf; + }; + + struct InputPersistenceStreams { + std::shared_ptr header_stream; + std::shared_ptr data_level0_stream; + std::shared_ptr length_stream; + std::shared_ptr link_list_stream; + }; + + // --- Memory Buffer Container for Persistence --- + struct HnswData { + bool own_buffers; + char* header_buffer; + size_t header_size; + char* data_level0_buffer; + size_t data_level0_size; + char* length_buffer; + size_t length_size; + char* link_list_buffer; + size_t link_list_size; + + HnswData(bool own_buffers) : own_buffers(own_buffers), header_buffer(nullptr), header_size(0), + data_level0_buffer(nullptr), data_level0_size(0), + length_buffer(nullptr), length_size(0), + link_list_buffer(nullptr), link_list_size(0) {} + + ~HnswData() { if (own_buffers) free_buffers(); } + + void free_buffers() { + if (header_buffer) { + free(header_buffer); + header_buffer = nullptr; + header_size = 0; + } + if (data_level0_buffer) { + free(data_level0_buffer); + data_level0_buffer = nullptr; + data_level0_size = 0; + } + if (length_buffer) { + free(length_buffer); + length_buffer = nullptr; + length_size = 0; + } + if (link_list_buffer) { + free(link_list_buffer); + link_list_buffer = nullptr; + link_list_size = 0; + } + } + + // Allocate buffers with specified sizes + void allocate_buffers(size_t header_sz, size_t data_sz, size_t length_sz, size_t link_sz) { + free_buffers(); // Free any existing buffers first + + header_buffer = (char*)malloc(header_sz); + if (!header_buffer) throw std::runtime_error("Failed to allocate header buffer"); + header_size = header_sz; + + data_level0_buffer = (char*)malloc(data_sz); + if (!data_level0_buffer) throw std::runtime_error("Failed to allocate data_level0 buffer"); + data_level0_size = data_sz; + + length_buffer = (char*)malloc(length_sz); + if (!length_buffer) throw std::runtime_error("Failed to allocate length buffer"); + length_size = length_sz; + + link_list_buffer = (char*)malloc(link_sz); + if (!link_list_buffer) throw std::runtime_error("Failed to allocate link_list buffer"); + link_list_size = link_sz; + } + + + // Only for testing + bool matchesWithDirectory(const std::string& directory) const { + if (header_buffer == nullptr || data_level0_buffer == nullptr || length_buffer == nullptr || link_list_buffer == nullptr) { + printf("HnswData is not initialized\n"); + return false; + } + + struct file_test_info { + const char *filename; + const char *buffer_ptr; + size_t buffer_size; + }; + + file_test_info files[] = { + {"header", header_buffer, header_size}, + {"data_level0", data_level0_buffer, data_level0_size}, + {"length", length_buffer, length_size}, + {"link_lists", link_list_buffer, link_list_size} + }; + + for (const auto& file : files) { + printf("testing %s\n", file.filename); + std::ifstream file_stream(directory + "/" + file.filename + ".bin"); + if (!file_stream.is_open()) { + printf("File %s not found\n", file.filename); + return false; + } + + file_stream.seekg(0, file_stream.end); + size_t file_size = file_stream.tellg(); + file_stream.seekg(0, file_stream.beg); + + if (file_size != file.buffer_size) { + printf("File %s size mismatch %ld != %ld\n", file.filename, file_size, file.buffer_size); + return false; + } + + std::vector file_content(file_size); + file_stream.read(file_content.data(), file_size); + file_stream.close(); + + if (file_content != std::vector(file.buffer_ptr, file.buffer_ptr + file.buffer_size)) { + printf("File %s content mismatch\n", file.filename); + printf("File content:\n"); + for (size_t i = 0; i < file_size; i++) { + printf("%02x ", file_content[i]); + } + printf("\n"); + printf("Buffer content:\n"); + for (size_t i = 0; i < file.buffer_size; i++) { + printf("%02x ", file.buffer_ptr[i]); + } + printf("\n"); + return false; + } + } + + return true; + } + }; + typedef unsigned int tableint; typedef unsigned int linklistsizeint; const int PERSISTENCE_VERSION = 1; // Used by persistent indices to check if the index on disk is compatible with the code @@ -114,6 +285,21 @@ namespace hnswlib } } + HierarchicalNSW( + SpaceInterface *s, + const HnswData *buffers, + bool nmslib = false, + size_t max_elements = 0, + bool allow_replace_deleted = false, + bool normalize = false) + : allow_replace_deleted_(allow_replace_deleted), + normalize_(normalize), + persist_on_write_(false), + persist_location_("") + { + loadPersistedIndexFromMemory(s, buffers, max_elements); + } + HierarchicalNSW( SpaceInterface *s, size_t max_elements, @@ -732,6 +918,8 @@ namespace hnswlib std::ofstream output(location, std::ios::binary); std::streampos position; + // IF THIS IS CHANGED: PLEASE MAKE CORRESPONDING MODIFICATIONS TO + // THE HEADER_FIELDS MACRO. writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); writeBinaryPOD(output, cur_element_count); @@ -879,29 +1067,13 @@ namespace hnswlib openPersistentIndex(); } - void persistHeader(std::ofstream &output_header) + void persistHeader(std::ostream &output_header) { - if (!persist_on_write_) - { - throw std::runtime_error("persistHeader called for an index that is not set to persist on write"); - } - output_header.seekp(0, std::ios::beg); - writeBinaryPOD(output_header, PERSISTENCE_VERSION); - writeBinaryPOD(output_header, offsetLevel0_); - writeBinaryPOD(output_header, max_elements_); - writeBinaryPOD(output_header, cur_element_count); // needs to be updated - writeBinaryPOD(output_header, size_data_per_element_); - writeBinaryPOD(output_header, label_offset_); - writeBinaryPOD(output_header, offsetData_); - writeBinaryPOD(output_header, maxlevel_); // needs to be updated - writeBinaryPOD(output_header, enterpoint_node_); // does this need to be updated? - writeBinaryPOD(output_header, maxM_); - - writeBinaryPOD(output_header, maxM0_); - writeBinaryPOD(output_header, M_); - writeBinaryPOD(output_header, mult_); // does this need to be updated? - writeBinaryPOD(output_header, ef_construction_); + + #define WRITE_ACTION(field) writeBinaryPOD(output_header, field); + HEADER_FIELDS(WRITE_ACTION) + #undef WRITE_ACTION output_header.flush(); } @@ -974,11 +1146,99 @@ namespace hnswlib elements_to_persist_.clear(); } - void loadPersistedIndex(SpaceInterface *s, size_t max_elements_i = 0) + constexpr size_t calculateHeaderSize() const { + #define SIZE_ACTION(field) + sizeof(field) + return 0 HEADER_FIELDS(SIZE_ACTION); + #undef SIZE_ACTION + } + + size_t calculateLinkListSize() { + size_t total_size = 0; + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + total_size += sizeof(unsigned int) + linkListSize; + } + return total_size; + } + + size_t calculateDataLevel0Size() { + return max_elements_ * size_data_per_element_; + } + + size_t calculateLengthSize() { + return max_elements_ * sizeof(float); + } + + void serializeHeaderToBuffer(char* buffer, size_t buffer_size) { + char *init_buffer = buffer; + #define WRITE_ACTION(field) \ + do { memcpy(buffer, &(field), sizeof(field)); buffer += sizeof(field); } while (0); + HEADER_FIELDS(WRITE_ACTION) + #undef WRITE_ACTION + } + + void serializeDataLevel0ToBuffer(char* buffer, size_t buffer_size) { + memcpy(buffer, data_level0_memory_, max_elements_ * size_data_per_element_); + } + + void serializeLengthToBuffer(char* buffer, size_t buffer_size) { + memcpy(buffer, length_memory_, max_elements_ * sizeof(float)); + } + + void serializeLinkListsToBuffer(char* buffer, size_t buffer_size) { + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + memcpy(buffer, &linkListSize, sizeof(unsigned int)); + buffer += sizeof(unsigned int); + if (linkListSize) { + memcpy(buffer, linkLists_[i], linkListSize); + buffer += linkListSize; + } + } + } + + // Memory-based persistence functions + HnswData* serializeToHnswData() { + // Calculate buffer sizes + size_t header_size = calculateHeaderSize(); + size_t data_level0_size = calculateDataLevel0Size(); + size_t length_size = calculateLengthSize(); + size_t link_list_size = calculateLinkListSize(); + + // Create and allocate memory buffers + HnswData* data = createHnswData(); + data->allocate_buffers(calculateHeaderSize(), calculateDataLevel0Size(), calculateLengthSize(), calculateLinkListSize()); + + // Serialize to memory buffers + serializeHeaderToBuffer(data->header_buffer, data->header_size); + serializeDataLevel0ToBuffer(data->data_level0_buffer, data->data_level0_size); + serializeLengthToBuffer(data->length_buffer, data->length_size); + serializeLinkListsToBuffer(data->link_list_buffer, data->link_list_size); + + return data; + } + + void saveToHnswData(HnswData** data_ptr) { + if (!data_ptr) { + throw std::runtime_error("saveToHnswData called with null data"); + } + + + HnswData *data = serializeToHnswData(); + *data_ptr = data; + } + + void readPersistedIndexFromStreams(SpaceInterface *s, + InputPersistenceStreams& input_streams, + size_t max_elements_i = 0) { - std::ifstream input_header(this->getHeaderLocation(), std::ios::binary); - if (!input_header.is_open()) - throw std::runtime_error("Cannot open header file"); + auto& input_header = *input_streams.header_stream; + auto& input_data_level0 = *input_streams.data_level0_stream; + auto& input_length = *input_streams.length_stream; + auto& input_link_list = *input_streams.link_list_stream; + + if (!input_header.good()) + throw std::runtime_error("Header stream is not in good state"); // Read the header int persisted_version; @@ -1006,49 +1266,134 @@ namespace hnswlib readBinaryPOD(input_header, M_); readBinaryPOD(input_header, mult_); readBinaryPOD(input_header, ef_construction_); - input_header.close(); data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); // Read data_level0_memory_ - std::ifstream input_data_level0(this->getDataLevel0Location(), std::ios::binary); - if (!input_data_level0.is_open()) - throw std::runtime_error("Cannot open data_level0 file"); + if (!input_data_level0.good()) + throw std::runtime_error("Data level0 stream is not in good state"); data_level0_memory_ = (char *)malloc(max_elements * size_data_per_element_); if (data_level0_memory_ == nullptr) throw std::runtime_error("Not enough memory: loadPersistedIndex failed to allocate level0"); input_data_level0.read(data_level0_memory_, max_elements * size_data_per_element_); - input_data_level0.close(); // Read length_memory_ - std::ifstream input_length(this->getLengthLocation(), std::ios::binary); - if (!input_length.is_open()) - throw std::runtime_error("Cannot open length file"); + if (!input_length.good()) + throw std::runtime_error("Length stream is not in good state"); length_memory_ = (char *)malloc(max_elements * sizeof(float)); if (length_memory_ == nullptr) throw std::runtime_error("Not enough memory: loadPersistedIndex failed to allocate length_memory_"); input_length.read(length_memory_, max_elements * sizeof(float)); - input_length.close(); // Read the linkLists + if (!input_link_list.good()) + throw std::runtime_error("Link list stream is not in good state"); + loadLinkLists(input_link_list); + + loadDeleted(); + return; + } + + void loadPersistedIndex(SpaceInterface *s, size_t max_elements_i = 0) + { + std::ifstream input_header(this->getHeaderLocation(), std::ios::binary); + if (!input_header.is_open()) + throw std::runtime_error("Cannot open header file"); + + std::ifstream input_data_level0(this->getDataLevel0Location(), std::ios::binary); + if (!input_data_level0.is_open()) + throw std::runtime_error("Cannot open data_level0 file"); + + std::ifstream input_length(this->getLengthLocation(), std::ios::binary); + if (!input_length.is_open()) + throw std::runtime_error("Cannot open length file"); + std::ifstream input_link_list(this->getLinkListLocation(), std::ios::binary); if (!input_link_list.is_open()) throw std::runtime_error("Cannot open link list file"); - loadLinkLists(input_link_list); - input_link_list.close(); - loadDeleted(); + { + InputPersistenceStreams input_streams = { + std::make_shared(std::move(input_header)), + std::make_shared(std::move(input_data_level0)), + std::make_shared(std::move(input_length)), + std::make_shared(std::move(input_link_list)) + }; + + readPersistedIndexFromStreams(s, input_streams, max_elements_i); + } openPersistentIndex(); - return; } + + void loadPersistedIndexFromMemory(SpaceInterface *s, + const HnswData *buffers, + size_t max_elements_i = 0) + { + // This function expects null-terminated C-style strings as per the memory + // Create streams from the extracted data + + InputPersistenceStreams input_streams = { + std::make_shared(buffers->header_buffer, buffers->header_size), + std::make_shared(buffers->data_level0_buffer, buffers->data_level0_size), + std::make_shared(buffers->length_buffer, buffers->length_size), + std::make_shared(buffers->link_list_buffer, buffers->link_list_size) + }; + + readPersistedIndexFromStreams(s, input_streams, max_elements_i); + } + + // C-style functions for Rust FFI integration + HnswData* createHnswData() { + return new HnswData(true /* own_buffers */); + } + + void freeHnswData(HnswData* data) { + if (data) { + delete data; + } + } + + // Get buffer pointers and sizes for Rust FFI + char* getHeaderBuffer(HnswData* data) { + return data ? data->header_buffer : nullptr; + } + + size_t getHeaderBufferSize(HnswData* data) { + return data ? data->header_size : 0; + } + + char* getDataLevel0Buffer(HnswData* data) { + return data ? data->data_level0_buffer : nullptr; + } + + size_t getDataLevel0BufferSize(HnswData* data) { + return data ? data->data_level0_size : 0; + } + + char* getLengthBuffer(HnswData* data) { + return data ? data->length_buffer : nullptr; + } + + size_t getLengthBufferSize(HnswData* data) { + return data ? data->length_size : 0; + } + + char* getLinkListBuffer(HnswData* data) { + return data ? data->link_list_buffer : nullptr; + } + + size_t getLinkListBufferSize(HnswData* data) { + return data ? data->link_list_size : 0; + } + // #pragma endregion - void loadLinkLists(std::ifstream &input_link_list) + void loadLinkLists(std::istream &input_link_list) { // Init link lists / visited lists pool size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); diff --git a/src/bindings.cpp b/src/bindings.cpp index 92e1c4ac..33dd96f9 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -1,6 +1,5 @@ // Assumes that chroma-hnswlib is checked out at the same level as chroma #include "../hnswlib/hnswlib.h" -#include class AllowAndDisallowListFilterFunctor : public hnswlib::BaseFilterFunctor { @@ -99,6 +98,32 @@ class Index index_inited = true; } + void load_index_from_buffers(const char *header_buffer, size_t header_len, + const char *data_buffer, size_t data_len, + const char *index_buffer, size_t index_len, + const char *deleted_buffer, size_t deleted_len, + const bool allow_replace_deleted, + const bool normalize, + const size_t max_elements) + { + if (index_inited) + { + throw std::runtime_error("Index already inited"); + } + + std::string header_str(header_buffer, header_len); + std::string data_str(data_buffer, data_len); + std::string index_str(index_buffer, index_len); + std::string deleted_str(deleted_buffer, deleted_len); + + appr_alg = new hnswlib::HierarchicalNSW(l2space, header_str.c_str(), data_str.c_str(), + index_str.c_str(), deleted_str.c_str(), + false, max_elements, allow_replace_deleted, normalize); + // TODO(rescrv,sicheng): check integrity + // appr_alg->checkIntegrity(); + index_inited = true; + } + void persist_dirty() { if (!index_inited) @@ -122,7 +147,7 @@ class Index { if (!index_inited) { - throw std::runtime_error("Inde not inited"); + throw std::runtime_error("Index not inited"); } std::vector ret_data = appr_alg->template getDataByLabel(id); // This checks if id is deleted for (int i = 0; i < dim; i++) @@ -507,4 +532,132 @@ extern "C" { index->close_fd(); } + + // Memory buffer serialization functions + // Can throw std::exception + hnswlib::HnswData* serialize_to_hnsw_data(Index *index) + { + try + { + return index->appr_alg->serializeToHnswData(); + } + catch (std::exception &e) + { + last_error = e.what(); + return nullptr; + } + last_error.clear(); + } + + // Can throw std::exception + void load_index_from_hnsw_data(Index *index, const hnswlib::HnswData* data, const size_t max_elements) + { + if (index->index_inited) + { + last_error = "Index already inited"; + return; + } + + try + { + + index->appr_alg = new hnswlib::HierarchicalNSW(index->l2space, + data, false, + max_elements, false, + index->normalize); + index->index_inited = true; + } + catch (std::exception &e) + { + last_error = e.what(); + return; + } + last_error.clear(); + } + + // Memory buffer management functions + hnswlib::HnswData* create_hnsw_data() + { + return new hnswlib::HnswData(false /* own_buffers */); + } + + void free_hnsw_data(hnswlib::HnswData* data) + { + if (data) { + delete data; + } + } + + void set_header_buffer(hnswlib::HnswData* data, char* header_buffer, size_t header_size) + { + if (data) { + data->header_buffer = header_buffer; + data->header_size = header_size; + } + } + + void set_data_level0_buffer(hnswlib::HnswData* data, char* data_level0_buffer, size_t data_level0_size) + { + if (data) { + data->data_level0_buffer = data_level0_buffer; + data->data_level0_size = data_level0_size; + } + } + + void set_length_buffer(hnswlib::HnswData* data, char* length_buffer, size_t length_size) + { + if (data) { + data->length_buffer = length_buffer; + data->length_size = length_size; + } + } + + void set_link_list_buffer(hnswlib::HnswData* data, char* link_list_buffer, size_t link_list_size) + { + if (data) { + data->link_list_buffer = link_list_buffer; + data->link_list_size = link_list_size; + } + } + + // Buffer access functions for Rust FFI + char* get_header_buffer(hnswlib::HnswData* data) + { + return data ? data->header_buffer : nullptr; + } + + size_t get_header_buffer_size(hnswlib::HnswData* data) + { + return data ? data->header_size : 0; + } + + char* get_data_level0_buffer(hnswlib::HnswData* data) + { + return data ? data->data_level0_buffer : nullptr; + } + + size_t get_data_level0_buffer_size(hnswlib::HnswData* data) + { + return data ? data->data_level0_size : 0; + } + + char* get_length_buffer(hnswlib::HnswData* data) + { + return data ? data->length_buffer : nullptr; + } + + size_t get_length_buffer_size(hnswlib::HnswData* data) + { + return data ? data->length_size : 0; + } + + char* get_link_list_buffer(hnswlib::HnswData* data) + { + return data ? data->link_list_buffer : nullptr; + } + + size_t get_link_list_buffer_size(hnswlib::HnswData* data) + { + return data ? data->link_list_size : 0; + } } diff --git a/src/hnsw.rs b/src/hnsw.rs index 928f4ba4..da8a7caa 100644 --- a/src/hnsw.rs +++ b/src/hnsw.rs @@ -1,6 +1,6 @@ use std::{ - ffi::{c_char, c_int, CString}, - path::PathBuf, + ffi::{c_char, c_int, c_uchar, CString}, + path::{Path, PathBuf}, str::Utf8Error, }; use thiserror::Error; @@ -12,6 +12,13 @@ struct HnswIndexPtrFFI { _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, } +// Opaque struct for memory buffers +#[repr(C)] +pub struct HnswDataFFI { + _data: [u8; 0], + _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + #[link(name = "bindings", kind = "static")] extern "C" { fn create_index(space_name: *const c_char, dim: c_int) -> *const HnswIndexPtrFFI; @@ -68,6 +75,29 @@ extern "C" { fn capacity(index: *const HnswIndexPtrFFI) -> c_int; fn resize_index(index: *const HnswIndexPtrFFI, new_size: usize); fn get_last_error(index: *const HnswIndexPtrFFI) -> *const c_char; + + // Memory buffer functions + fn serialize_to_hnsw_data(index: *const HnswIndexPtrFFI) -> *const HnswDataFFI; + fn load_index_from_hnsw_data(index: *const HnswIndexPtrFFI, buffers: *const HnswDataFFI, max_elements: usize); + + // Memory buffer management + fn create_hnsw_data(own_buffers: bool) -> *const HnswDataFFI; + fn free_hnsw_data(buffers: *const HnswDataFFI); + + // Buffer access functions + fn get_header_buffer(buffers: *const HnswDataFFI) -> *const c_uchar; + fn get_header_buffer_size(buffers: *const HnswDataFFI) -> usize; + fn get_data_level0_buffer(buffers: *const HnswDataFFI) -> *const c_uchar; + fn get_data_level0_buffer_size(buffers: *const HnswDataFFI) -> usize; + fn get_length_buffer(buffers: *const HnswDataFFI) -> *const c_uchar; + fn get_length_buffer_size(buffers: *const HnswDataFFI) -> usize; + fn get_link_list_buffer(buffers: *const HnswDataFFI) -> *const c_uchar; + fn get_link_list_buffer_size(buffers: *const HnswDataFFI) -> usize; + + fn set_header_buffer(buffers: *const HnswDataFFI, buffer: *const c_uchar, size: usize); + fn set_data_level0_buffer(buffers: *const HnswDataFFI, buffer: *const c_uchar, size: usize); + fn set_length_buffer(buffers: *const HnswDataFFI, buffer: *const c_uchar, size: usize); + fn set_link_list_buffer(buffers: *const HnswDataFFI, buffer: *const c_uchar, size: usize); } #[derive(Error, Debug)] @@ -122,6 +152,7 @@ pub struct HnswIndexLoadConfig { pub dimensionality: i32, pub persist_path: PathBuf, pub ef_search: usize, + pub hnsw_data: HnswData, } pub struct HnswIndexInitConfig { @@ -332,6 +363,177 @@ impl HnswIndex { unsafe { set_ef(self.ffi_ptr, ef as c_int) } read_and_return_hnsw_error(self.ffi_ptr) } + + /// Serialize the index to memory buffers + pub fn serialize_to_hnsw_data(&self) -> Result { + let buffers_ptr = unsafe { serialize_to_hnsw_data(self.ffi_ptr) }; + read_and_return_hnsw_error(self.ffi_ptr)?; + + if buffers_ptr.is_null() { + return Err(HnswError::FFIException("Failed to serialize to memory buffers".to_string())); + } + + Ok(HnswData::new_from_ffi(buffers_ptr)) + } + + /// Load index from memory buffers + pub fn load_from_hnsw_data(config: HnswIndexLoadConfig) -> Result { + let distance_function_string: String = config.distance_function.into(); + let space_name = CString::new(distance_function_string) + .map_err(|e| HnswInitError::InvalidDistanceFunction(e.to_string()))?; + + let ffi_ptr = unsafe { create_index(space_name.as_ptr(), config.dimensionality) }; + read_and_return_hnsw_error(ffi_ptr)?; + + unsafe { + load_index_from_hnsw_data(ffi_ptr, config.hnsw_data.ffi_ptr, DEFAULT_MAX_ELEMENTS); + } + read_and_return_hnsw_error(ffi_ptr)?; + + let hnsw_index = HnswIndex { + ffi_ptr, + dimensionality: config.dimensionality, + }; + hnsw_index.set_ef(config.ef_search)?; + Ok(hnsw_index) + } +} + +/// Safe wrapper for memory buffers containing serialized HNSW index data +pub struct HnswData { + ffi_ptr: *const HnswDataFFI, + _marker: std::marker::PhantomData<*mut ()>, // Prevents Copy trait + // Hold Arc references to prevent premature buffer deallocation + header_buffer: Option>>, + data_level0_buffer: Option>>, + length_buffer: Option>>, + link_list_buffer: Option>>, +} + +unsafe impl Sync for HnswData {} +unsafe impl Send for HnswData {} + +impl Default for HnswData { + fn default() -> Self { + Self { + ffi_ptr: std::ptr::null(), + _marker: std::marker::PhantomData, + header_buffer: None, + data_level0_buffer: None, + length_buffer: None, + link_list_buffer: None, + } + } +} + +impl HnswData { + /// Create new empty memory buffers with ownership (default: owning) + pub fn new() -> Self { + Self::new_non_owning() + } + + /// Create new empty memory buffers that own their data + pub fn new_owning() -> Self { + let ffi_ptr = unsafe { create_hnsw_data(true) }; + HnswData::new_from_ffi(ffi_ptr) + } + + /// Create new empty memory buffers that do not own their data + pub fn new_non_owning() -> Self { + let ffi_ptr = unsafe { create_hnsw_data(false) }; + HnswData::new_from_ffi(ffi_ptr) + } + + /// Create new memory buffers from an existing FFI pointer + pub fn new_from_ffi(ffi_ptr: *const HnswDataFFI) -> Self { + HnswData { + ffi_ptr, + _marker: std::marker::PhantomData, + header_buffer: None, + data_level0_buffer: None, + length_buffer: None, + link_list_buffer: None, + } + } + + pub fn set_buffers( + &mut self, + header_buffer: std::sync::Arc>, + data_level0_buffer: std::sync::Arc>, + length_buffer: std::sync::Arc>, + link_list_buffer: std::sync::Arc> + ) { + unsafe { + set_header_buffer(self.ffi_ptr, header_buffer.as_ptr(), header_buffer.len()); + set_data_level0_buffer(self.ffi_ptr, data_level0_buffer.as_ptr(), data_level0_buffer.len()); + set_length_buffer(self.ffi_ptr, length_buffer.as_ptr(), length_buffer.len()); + set_link_list_buffer(self.ffi_ptr, link_list_buffer.as_ptr(), link_list_buffer.len()); + } + + // Store Arc references to prevent premature deallocation + self.header_buffer = Some(header_buffer); + self.data_level0_buffer = Some(data_level0_buffer); + self.length_buffer = Some(length_buffer); + self.link_list_buffer = Some(link_list_buffer); + } + + /// Get the header buffer as a byte slice + pub fn header_buffer(&self) -> &[u8] { + unsafe { + let ptr = get_header_buffer(self.ffi_ptr); + let size = get_header_buffer_size(self.ffi_ptr); + if ptr.is_null() || size == 0 { + &[] + } else { + std::slice::from_raw_parts(ptr as *const u8, size) + } + } + } + + /// Get the data level 0 buffer as a byte slice + pub fn data_level0_buffer(&self) -> &[u8] { + unsafe { + let ptr = get_data_level0_buffer(self.ffi_ptr); + let size = get_data_level0_buffer_size(self.ffi_ptr); + if ptr.is_null() || size == 0 { + &[] + } else { + std::slice::from_raw_parts(ptr as *const u8, size) + } + } + } + + /// Get the length buffer as a byte slice + pub fn length_buffer(&self) -> &[u8] { + unsafe { + let ptr = get_length_buffer(self.ffi_ptr); + let size = get_length_buffer_size(self.ffi_ptr); + if ptr.is_null() || size == 0 { + &[] + } else { + std::slice::from_raw_parts(ptr as *const u8, size) + } + } + } + + /// Get the link list buffer as a byte slice + pub fn link_list_buffer(&self) -> &[u8] { + unsafe { + let ptr = get_link_list_buffer(self.ffi_ptr); + let size = get_link_list_buffer_size(self.ffi_ptr); + if ptr.is_null() || size == 0 { + &[] + } else { + std::slice::from_raw_parts(ptr as *const u8, size) + } + } + } +} + +impl Drop for HnswData { + fn drop(&mut self) { + unsafe { if !self.ffi_ptr.is_null() { free_hnsw_data(self.ffi_ptr) } } + } } fn read_and_return_hnsw_error(ffi_ptr: *const HnswIndexPtrFFI) -> Result<(), HnswError> { @@ -355,6 +557,7 @@ impl Drop for HnswIndex { pub mod test { use std::fs::OpenOptions; use std::io::Write; + use std::sync::Arc; use super::*; use rand::seq::IteratorRandom; @@ -615,6 +818,7 @@ pub mod test { dimensionality: d as i32, persist_path: persist_path.to_path_buf(), ef_search: 100, + hnsw_data: HnswData::default(), }); let index = match index { @@ -637,6 +841,161 @@ pub mod test { index_data_same(&index, &ids, &data, d); } + #[test] + fn it_can_persist_and_load_from_memory() { + let n = 1000; + let d: usize = 960; + let distance_function = HnswDistanceFunction::Euclidean; + let tmp_dir = tempdir().unwrap(); + let persist_path = tmp_dir.path(); + let index = HnswIndex::init(HnswIndexInitConfig { + distance_function, + dimensionality: d as i32, + max_elements: n, + m: 16, + ef_construction: 100, + ef_search: 100, + random_seed: 0, + persist_path: Some(persist_path.to_path_buf()), + }); + + let index = match index { + Err(e) => panic!("Error initializing index: {}", e), + Ok(index) => index, + }; + + let data: Vec = generate_random_data(n, d); + let ids: Vec = (0..n).collect(); + + (0..n).for_each(|i| { + let data = &data[i * d..(i + 1) * d]; + index.add(ids[i], data).expect("Should not error"); + }); + + // Persist the index + let res = index.save(); + if let Err(e) = res { + panic!("Error saving index: {}", e); + } + + // Load the index from memory instead + let files = ["header", "data_level0", "length", "link_lists"]; + let ext = "bin"; + let mut src_buffers = Vec::new(); + + for file in files { + let path = persist_path.join(file).with_extension(ext); + let data = std::fs::read(path).expect("Unable to read file"); + src_buffers.push(Arc::new(data)); + } + + let mut hnsw_data = HnswData::new(); + hnsw_data.set_buffers(src_buffers[0].clone(), src_buffers[1].clone(), + src_buffers[2].clone(), src_buffers[3].clone()); + + let index = HnswIndex::load_from_hnsw_data(HnswIndexLoadConfig { + distance_function, + dimensionality: d as i32, + persist_path: "".into(), + ef_search: 100, + hnsw_data, + }); + + let index = match index { + Err(e) => panic!("Error loading index: {}", e), + Ok(index) => index, + }; + assert_eq!(index.get_ef().expect("Expected to get ef_search"), 100); + + // Query the data + let query = &data[0..d]; + let allow_ids = &[]; + let disallow_ids = &[]; + let (ids, distances) = index.query(query, 1, allow_ids, disallow_ids).unwrap(); + assert_eq!(ids.len(), 1); + assert_eq!(distances.len(), 1); + assert_eq!(ids[0], 0); + assert_eq!(distances[0], 0.0); + + // Get the data and check it + index_data_same(&index, &ids, &data, d); + } + + #[test] + fn it_can_serialize_and_deserialize_hnsw_data() { + let n = 100; + let d: usize = 128; + let distance_function = HnswDistanceFunction::Euclidean; + let tmp_dir = tempdir().unwrap(); + let persist_path = tmp_dir.path(); + + // Create and populate original index + let original_index = HnswIndex::init(HnswIndexInitConfig { + distance_function, + dimensionality: d as i32, + max_elements: n, + m: 16, + ef_construction: 100, + ef_search: 100, + random_seed: 42, + persist_path: Some(persist_path.to_path_buf()), + }).expect("Failed to create original index"); + + let data: Vec = generate_random_data(n, d); + let ids: Vec = (0..n).collect(); + + // Add data to original index + for i in 0..n { + let data_slice = &data[i * d..(i + 1) * d]; + original_index.add(ids[i], data_slice).expect("Should not error"); + } + + // Verify original index has correct data + assert_eq!(original_index.len(), n); + index_data_same(&original_index, &ids, &data, d); + + // Serialize to memory buffers + let hnsw_data = original_index.serialize_to_hnsw_data() + .expect("Failed to serialize to memory buffers"); + + // Verify buffers are not empty + assert!(!hnsw_data.header_buffer().is_empty(), "Header buffer should not be empty"); + assert!(!hnsw_data.data_level0_buffer().is_empty(), "Data level0 buffer should not be empty"); + assert!(!hnsw_data.length_buffer().is_empty(), "Length buffer should not be empty"); + assert!(!hnsw_data.link_list_buffer().is_empty(), "Link list buffer should not be empty"); + + // Create new index from memory buffers + let loaded_index = HnswIndex::load_from_hnsw_data( + HnswIndexLoadConfig { + distance_function, + dimensionality: d as i32, + persist_path: "".into(), + ef_search: 100, + hnsw_data, + }, + ).expect("Failed to load from memory buffers"); + + // Verify loaded index has same data + assert_eq!(loaded_index.len(), n); + index_data_same(&loaded_index, &ids, &data, d); + + // Test querying both indices to ensure they behave the same + let query_vector = &data[0..d]; // Use first vector as query + let k = 5; + + let (original_ids, original_distances) = original_index.query(query_vector, k, &[], &[]) + .expect("Query should not error"); + let (loaded_ids, loaded_distances) = loaded_index.query(query_vector, k, &[], &[]) + .expect("Query should not error"); + + // Results should be identical + assert_eq!(original_ids, loaded_ids, "Query results should be identical"); + assert_eq!(original_distances.len(), loaded_distances.len()); + for (orig_dist, loaded_dist) in original_distances.iter().zip(loaded_distances.iter()) { + assert!((orig_dist - loaded_dist).abs() < EPS, "Distances should be nearly identical"); + } + } + #[test] fn it_can_add_and_query_with_allowed_and_disallowed_ids() { let n = 1000; @@ -817,6 +1176,7 @@ pub mod test { dimensionality: d as i32, persist_path: persist_path.to_path_buf(), ef_search: 100, + hnsw_data: HnswData::default(), }); assert!(index.is_err()); diff --git a/tests/cpp/persistent_test.cpp b/tests/cpp/persistent_test.cpp index d93e66af..de8a3ec0 100644 --- a/tests/cpp/persistent_test.cpp +++ b/tests/cpp/persistent_test.cpp @@ -1,4 +1,5 @@ #include "../../hnswlib/hnswlib.h" +#include "hnswlib/hnswalg.h" #include @@ -81,6 +82,63 @@ namespace delete alg_hnsw; } + void testMemorySerialization() + { + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + int d = 1536; + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; i++) + { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) + { + query[i] = distrib(rng); + } + + hnswlib::InnerProductSpace space(d); + hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n, 16, 200, 100, false, false, true, "."); + + for (size_t i = 0; i < n; i++) + { + alg_hnsw->addPoint(data.data() + d * i, i); + if (i % 7 == 0) + alg_hnsw->persistDirty(); + } + + hnswlib::HnswData* ser_data = nullptr; + alg_hnsw->saveToHnswData(&ser_data); + assert(ser_data); + alg_hnsw->persistDirty(); + + + // This test owns the buffer memory now so delete alg_hnsw + delete alg_hnsw; + + assert(ser_data->matchesWithDirectory(".")); + + hnswlib::HierarchicalNSW *alg_hnsw2 = new hnswlib::HierarchicalNSW(&space, ser_data, false, 2 * n, false, false); + + hnswlib::HnswData* ser_data2 = nullptr; + alg_hnsw2->saveToHnswData(&ser_data2); + delete alg_hnsw2; + assert(ser_data2); + assert(ser_data2->matchesWithDirectory(".")); + + + delete ser_data; + delete ser_data2; + } + void testResizePersistentIndex() { int d = 1536; @@ -530,6 +588,8 @@ int main() std::cout << "Testing ..." << std::endl; testPersistentIndex(); std::cout << "Test testPersistentIndex ok" << std::endl; + testMemorySerialization(); + std::cout << "Test testMemorySerialization ok" << std::endl; testResizePersistentIndex(); std::cout << "Test testResizePersistentIndex ok" << std::endl; testAddUpdatePersistentIndex();