diff --git a/docs/docs/en/src/advanced/introspection.md b/docs/docs/en/src/advanced/introspection.md index c23e62c4d..e55d2a7e9 100644 --- a/docs/docs/en/src/advanced/introspection.md +++ b/docs/docs/en/src/advanced/introspection.md @@ -45,7 +45,7 @@ Two overloads are provided: // Dense vector indexes (HGraph, BruteForce, IVF, DiskANN, HNSW) auto r = index->CalDistanceById(query_ptr, ids, count, /*calculate_precise_distance=*/true); -// Sparse vector indexes (SINDI, SparseIndex) — wrap the query in a Dataset +// Sparse vector indexes (SINDI) — wrap the query in a Dataset auto query_ds = vsag::Dataset::Make(); query_ds->NumElements(1)->SparseVectors(/* ... */); auto r = index->CalDistanceById(query_ds, ids, count, /*calculate_precise_distance=*/true); diff --git a/docs/docs/en/src/advanced/search_allocator.md b/docs/docs/en/src/advanced/search_allocator.md index 5a4b76ebf..6a542946c 100644 --- a/docs/docs/en/src/advanced/search_allocator.md +++ b/docs/docs/en/src/advanced/search_allocator.md @@ -43,7 +43,7 @@ index falls back to the allocator that was attached to its owning `Resource`. > **Availability.** `Index::SearchWithRequest` has a default implementation that returns an > *unsupported* error. Only HGraph, IVF, BruteForce and WARP implement it today > (`src/algorithm/{hgraph,ivf,brute_force,warp}.cpp`). For indexes that do not yet override -> `SearchWithRequest` (HNSW, DiskANN, SINDI, Pyramid, SparseIndex), use the legacy `SearchParam` +> `SearchWithRequest` (HNSW, DiskANN, SINDI, Pyramid), use the legacy `SearchParam` > path described below. ## Legacy API — `SearchParam::allocator` *(deprecated)* diff --git a/docs/docs/zh/src/advanced/introspection.md b/docs/docs/zh/src/advanced/introspection.md index f4eb5a678..3d5cb7a42 100644 --- a/docs/docs/zh/src/advanced/introspection.md +++ b/docs/docs/zh/src/advanced/introspection.md @@ -41,7 +41,7 @@ if (not index->CheckFeature(vsag::SUPPORT_DELETE_BY_ID)) { // 稠密向量索引(HGraph、BruteForce、IVF、DiskANN、HNSW) auto r = index->CalDistanceById(query_ptr, ids, count, /*calculate_precise_distance=*/true); -// 稀疏向量索引(SINDI、SparseIndex)—— 用 Dataset 封装查询 +// 稀疏向量索引(SINDI)—— 用 Dataset 封装查询 auto query_ds = vsag::Dataset::Make(); query_ds->NumElements(1)->SparseVectors(/* ... */); auto r = index->CalDistanceById(query_ds, ids, count, /*calculate_precise_distance=*/true); diff --git a/docs/docs/zh/src/advanced/search_allocator.md b/docs/docs/zh/src/advanced/search_allocator.md index cd5365f16..d0f8ee7ba 100644 --- a/docs/docs/zh/src/advanced/search_allocator.md +++ b/docs/docs/zh/src/advanced/search_allocator.md @@ -37,7 +37,7 @@ auto result = index->SearchWithRequest(req).value(); > **可用性。** `Index::SearchWithRequest` 默认实现会返回 *不支持* 错误。目前只有 HGraph、 > IVF、BruteForce、WARP 实现了它(`src/algorithm/{hgraph,ivf,brute_force,warp}.cpp`)。对于 -> 尚未 override 的索引(HNSW、DiskANN、SINDI、Pyramid、SparseIndex),请使用下文的旧版 +> 尚未 override 的索引(HNSW、DiskANN、SINDI、Pyramid),请使用下文的旧版 > `SearchParam` 路径。 ## 旧版 API —— `SearchParam::allocator`(已弃用) diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 7a3270f4c..32a66d705 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -23,6 +23,7 @@ extern const char* const INDEX_FRESH_HNSW; extern const char* const INDEX_PYRAMID; extern const char* const INDEX_SPARSE; extern const char* const INDEX_SINDI; +extern const char* const INDEX_DISKSINDI; extern const char* const INDEX_BRUTE_FORCE; extern const char* const INDEX_IVF; extern const char* const INDEX_WARP; diff --git a/include/vsag/index.h b/include/vsag/index.h index 8db07640f..9ed005d4d 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -51,7 +51,20 @@ struct MergeUnit { IdMapFunction id_map_func = nullptr; }; -enum class IndexType { HNSW, DISKANN, HGRAPH, IVF, PYRAMID, BRUTEFORCE, SPARSE, SINDI, WARP }; +enum class IndexType { + HNSW = 0, + DISKANN = 1, + HGRAPH = 2, + IVF = 3, + PYRAMID = 4, + BRUTEFORCE = 5, + // Kept for source compatibility with SparseIndex callers; new sparse workloads should + // prefer SINDI or DISKSINDI. + SPARSE = 6, + SINDI = 7, + WARP = 8, + DISKSINDI = 9, +}; #define DATA_FLAG_FLOAT32_VECTOR 0x01 #define DATA_FLAG_INT8_VECTOR 0x02 @@ -462,7 +475,7 @@ class Index { * * Suitable for dense vector indexes (HGraph, BruteForce, IVF, DiskANN, HNSW). * The query must be a contiguous float32 array with dimension matching the index. - * For sparse vector indexes (SINDI, SparseIndex), this overload is not applicable; + * For sparse vector indexes (SINDI, DiskSINDI), this overload is not applicable; * use CalcDistanceById(DatasetPtr, int64_t, bool) instead. * * @param vector The embedding of the query (float32 array for dense vectors). @@ -483,7 +496,7 @@ class Index { /** * @brief Calculate the distance between the query and the vector of the given ID. * - * Suitable for sparse vector indexes (SINDI, SparseIndex) where vectors + * Suitable for sparse vector indexes (SINDI, DiskSINDI) where vectors * cannot be represented as a simple float pointer. The Dataset should * contain sparse vectors via GetSparseVectors(). * For dense vector indexes (HGraph, BruteForce, IVF, DiskANN, HNSW), @@ -509,7 +522,7 @@ class Index { * * Suitable for dense vector indexes (HGraph, BruteForce, IVF, DiskANN, HNSW). * The query must be a contiguous float32 array. For sparse vector indexes - * (SINDI, SparseIndex), this overload is not applicable; use + * (SINDI, DiskSINDI), this overload is not applicable; use * CalDistanceById(DatasetPtr, const int64_t*, int64_t, bool) instead. * * @param query is the embedding of query (float32 array for dense vectors). @@ -532,7 +545,7 @@ class Index { /** * @brief Calculate the distance between the query and the vector of the given ID for batch. * - * Suitable for sparse vector indexes (SINDI, SparseIndex) where vectors + * Suitable for sparse vector indexes (SINDI, DiskSINDI) where vectors * cannot be represented as a simple float pointer. The Dataset should * contain sparse vectors via GetSparseVectors(). * For dense vector indexes (HGraph, BruteForce, IVF, DiskANN, HNSW), diff --git a/src/algorithm/CMakeLists.txt b/src/algorithm/CMakeLists.txt index bd6849d2c..6b8cdf4fe 100644 --- a/src/algorithm/CMakeLists.txt +++ b/src/algorithm/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory (hnswlib) add_subdirectory (sindi) +add_subdirectory (disksindi) add_subdirectory (hgraph) add_subdirectory (bruteforce) add_subdirectory (ivf) @@ -35,6 +36,7 @@ set (ALGORITHM_LIBS algorithm hnswlib sindi + disksindi hgraph bruteforce ivf diff --git a/src/algorithm/disksindi/CMakeLists.txt b/src/algorithm/disksindi/CMakeLists.txt new file mode 100644 index 000000000..8f1bf0bfe --- /dev/null +++ b/src/algorithm/disksindi/CMakeLists.txt @@ -0,0 +1,23 @@ + +# Copyright 2024-present the vsag project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set(DISKSINDI_SRCS + disksindi.cpp + disksindi_parameter.cpp +) + +add_library(disksindi OBJECT ${DISKSINDI_SRCS}) +target_link_libraries(disksindi PRIVATE coverage_config vsag_src_common) diff --git a/src/algorithm/disksindi/disksindi.cpp b/src/algorithm/disksindi/disksindi.cpp new file mode 100644 index 000000000..097d884e9 --- /dev/null +++ b/src/algorithm/disksindi/disksindi.cpp @@ -0,0 +1,1048 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "disksindi.h" + +#include +#include +#include +#include +#include +#include + +#include "algorithm/sparse_distance.h" +#include "datacell/sparse_vector_datacell_parameter.h" +#include "impl/filter/inner_id_wrapper_filter.h" +#include "impl/filter/white_list_filter.h" +#include "impl/heap/standard_heap.h" +#include "index_feature_list.h" +#include "inner_string_params.h" +#include "io/reader_io_parameter.h" +#include "quantization/sparse_quantization/sparse_quantizer.h" +#include "quantization/sparse_quantization/sparse_quantizer_parameter.h" +#include "storage/empty_index_binary_set.h" +#include "storage/serialization.h" +#include "utils/util_functions.h" +#include "vsag/allocator.h" +#include "vsag/options.h" +#include "vsag_exception.h" + +namespace vsag { + +namespace { + +constexpr const char* DISKSINDI_RERANK_FLAT_FORMAT_KEY = "disksindi_rerank_flat_format"; +constexpr int64_t DISKSINDI_RERANK_FLAT_FORMAT_DATACELL = 2; + +class BinaryReader : public Reader { +public: + explicit BinaryReader(Binary binary) : binary_(std::move(binary)) { + } + + void + Read(uint64_t offset, uint64_t len, void* dest) override { + std::memcpy(dest, binary_.data.get() + offset, len); + } + + void + AsyncRead(uint64_t offset, uint64_t len, void* dest, CallBack callback) override { + Read(offset, len, dest); + callback(IOErrorCode::IO_SUCCESS, "success"); + } + + [[nodiscard]] uint64_t + Size() const override { + return binary_.size; + } + +private: + Binary binary_; +}; + +class StreamBackedReader : public Reader { +public: + explicit StreamBackedReader(std::istream& stream) : stream_(stream) { + auto cursor = stream_.tellg(); + stream_.seekg(0, std::ios::end); + size_ = static_cast(stream_.tellg()); + stream_.seekg(cursor, std::ios::beg); + } + + void + Read(uint64_t offset, uint64_t len, void* dest) override { + std::lock_guard lock(mutex_); + stream_.seekg(static_cast(offset), std::ios::beg); + stream_.read(static_cast(dest), static_cast(len)); + } + + void + AsyncRead(uint64_t offset, uint64_t len, void* dest, CallBack callback) override { + Read(offset, len, dest); + callback(IOErrorCode::IO_SUCCESS, "success"); + } + + [[nodiscard]] uint64_t + Size() const override { + return size_; + } + +private: + std::istream& stream_; + uint64_t size_{0}; + std::mutex mutex_; +}; + +float +compute_distance_from_codes(const uint8_t* codes, + const Vector& sorted_ids, + const Vector& sorted_vals) { + auto len = *reinterpret_cast(codes); + const auto* entries = reinterpret_cast(codes + sizeof(uint32_t)); + float sum = 0.0F; + uint32_t i = 0; + uint32_t j = 0; + while (i < sorted_ids.size() && j < len) { + if (sorted_ids[i] < entries[j].id) { + i++; + } else if (sorted_ids[i] > entries[j].id) { + j++; + } else { + sum += sorted_vals[i] * entries[j].val; + i++; + j++; + } + } + return 1 - sum; +} + +float +cal_distance_by_id_unsafe(const FlattenInterfacePtr& flat, + const Vector& sorted_ids, + const Vector& sorted_vals, + uint32_t inner_id) { + bool need_release{false}; + const auto* codes = flat->GetCodesById(inner_id, need_release); + float distance = compute_distance_from_codes(codes, sorted_ids, sorted_vals); + if (need_release) { + flat->Release(codes); + } + return distance; +} + +DatasetPtr +collect_results(const DistHeapPtr& results, Allocator* allocator) { + auto [result, dists, ids] = + create_fast_dataset(static_cast(results->Size()), allocator); + if (results->Empty()) { + result->Dim(0)->NumElements(1); + return result; + } + + for (auto j = static_cast(results->Size() - 1); j >= 0; --j) { + dists[j] = results->Top().first; + ids[j] = results->Top().second; + results->Pop(); + } + return result; +} + +Vector +collect_query_term_ids(const SparseTermComputerPtr& computer, Allocator* allocator) { + Vector query_term_ids(allocator); + while (computer->HasNextTerm()) { + auto it = computer->NextTermIter(); + query_term_ids.push_back(computer->GetTerm(it)); + } + computer->ResetTerm(); + return query_term_ids; +} + +uint64_t +block_memory_ceil(uint64_t memory) { + const auto block_size = Options::Instance().block_size_limit(); + return ((memory + block_size - 1) / block_size) * block_size; +} + +FlattenInterfacePtr +create_rerank_flat(const IndexCommonParam& common_param, const IOParamPtr& io_param) { + auto rerank_param = std::make_shared(); + rerank_param->io_parameter = io_param; + rerank_param->quantizer_parameter = std::make_shared(); + return FlattenInterface::MakeInstance(rerank_param, common_param); +} + +void +deserialize_legacy_rerank_flat(StreamReader& reader, + const FlattenInterfacePtr& flat, + Allocator* allocator) { + int64_t cur_element_count = 0; + StreamReader::ReadObj(reader, cur_element_count); + flat->Resize(cur_element_count); + std::vector ids; + std::vector vals; + for (int64_t i = 0; i < cur_element_count; ++i) { + uint32_t len = 0; + StreamReader::ReadObj(reader, len); + ids.resize(len); + vals.resize(len); + reader.Read(reinterpret_cast(ids.data()), + static_cast(len) * sizeof(uint32_t)); + reader.Read(reinterpret_cast(vals.data()), + static_cast(len) * sizeof(float)); + SparseVector vector; + vector.len_ = len; + vector.ids_ = ids.data(); + vector.vals_ = vals.data(); + flat->InsertVector(&vector, i); + } + LabelTable legacy_label_table(allocator); + legacy_label_table.Deserialize(reader); +} + +void +deserialize_rerank_flat(StreamReader& reader, + const FlattenInterfacePtr& flat, + Allocator* allocator, + bool has_datacell_format) { + if (has_datacell_format) { + flat->Deserialize(reader); + return; + } + deserialize_legacy_rerank_flat(reader, flat, allocator); +} + +} // namespace + +ParamPtr +DiskSINDI::CheckAndMappingExternalParam(const JsonType& external_param, + const IndexCommonParam& common_param) { + auto ptr = std::make_shared(); + ptr->FromJson(external_param); + return ptr; +} + +DiskSINDI::DiskSINDI(const DiskSINDIParameterPtr& param, const IndexCommonParam& common_param) + : InnerIndexInterface(param, common_param), + use_reorder_(param->use_reorder), + use_quantization_(param->use_quantization), + term_id_limit_(param->term_id_limit), + window_size_(param->window_size), + doc_retain_ratio_(1.0F - param->doc_prune_ratio), + deserialize_without_footer_(param->deserialize_without_footer), + deserialize_without_buffer_(param->deserialize_without_buffer), + quantization_params_(std::make_shared()), + avg_doc_term_length_(param->avg_doc_term_length), + remap_term_ids_(param->remap_term_ids), + param_(param) { + CHECK_ARGUMENT(window_size_ > 0, "window_size must be in (0, 65536]"); + CHECK_ARGUMENT(window_size_ <= 65536, "window_size must be in (0, 65536]"); + CHECK_ARGUMENT(term_id_limit_ > 0, "term_id_limit must be > 0"); + if (remap_term_ids_) { + term_id_mapper_ = + std::make_shared(term_id_limit_, common_param.allocator_.get()); + } + if (use_reorder_) { + rerank_flat_ = create_rerank_flat(common_param, param->rerank_io_parameter); + } + term_datacell_ = DiskSparseTermListDataCellInterface::MakeInstance(doc_retain_ratio_, + term_id_limit_, + allocator_, + use_quantization_, + quantization_params_, + window_size_, + param->term_io_parameter, + common_param); +} + +std::string +DiskSINDI::GetStats() const { + return ""; +} + +std::vector +DiskSINDI::Add(const DatasetPtr& base, AddMode mode) { + std::scoped_lock wlock(this->global_mutex_); + + if (is_deserialized_) { + throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, + "DiskSINDI does not support Add after Deserialize"); + } + + std::vector failed_ids; + + auto data_num = base->GetNumElements(); + CHECK_ARGUMENT(data_num > 0, "data_num is zero when add vectors"); + + const auto* sparse_vectors = base->GetSparseVectors(); + const auto* ids = base->GetIds(); + const auto* extra_info = base->GetExtraInfos(); + const auto extra_info_size = base->GetExtraInfoSize(); + + if (use_quantization_ && cur_element_count_ == 0) { + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::lowest(); + for (int64_t i = 0; i < data_num; ++i) { + const auto& vec = sparse_vectors[i]; + for (int j = 0; j < vec.len_; ++j) { + float val = vec.vals_[j]; + if (val < min_val) { + min_val = val; + } + if (val > max_val) { + max_val = val; + } + } + } + quantization_params_->min_val = min_val; + quantization_params_->max_val = max_val; + quantization_params_->diff = max_val - min_val; + if (quantization_params_->diff < 1e-6) { + quantization_params_->diff = 1.0F; + } + } + + Vector tmp_ids(allocator_); + for (uint32_t i = 0; i < data_num; ++i) { + const auto& sparse_vector = sparse_vectors[i]; + if (label_table_->CheckLabel(ids[i])) { + failed_ids.push_back(ids[i]); + logger::warn("id ({}) already exists", ids[i]); + continue; + } + if (sparse_vector.len_ <= 0) { + failed_ids.push_back(ids[i]); + logger::warn( + "sparse_vector.len_ ({}) is invalid for id ({})", sparse_vector.len_, ids[i]); + continue; + } + + try { + if (remap_term_ids_) { + auto remapped = remap_sparse_vector_for_build(sparse_vector, tmp_ids); + term_datacell_->InsertVector(remapped, static_cast(cur_element_count_)); + } else { + term_datacell_->InsertVector(sparse_vector, + static_cast(cur_element_count_)); + } + } catch (const std::runtime_error& e) { + failed_ids.push_back(ids[i]); + logger::warn("runtime error: {}", e.what()); + continue; + } catch (const VsagException& e) { + failed_ids.push_back(ids[i]); + logger::warn("vsag exception: {}", e.what()); + continue; + } catch (const std::bad_alloc& e) { + failed_ids.push_back(ids[i]); + logger::warn("memory allocation failed: {}", e.what()); + continue; + } + + label_table_->Insert(cur_element_count_, ids[i]); + + if (extra_info_size > 0) { + extra_infos_->InsertExtraInfo(extra_info + i * extra_info_size, cur_element_count_); + } + + if (use_reorder_) { + rerank_flat_->InsertVector(sparse_vectors + i, cur_element_count_); + } + + cur_element_count_++; + } + + this->cal_memory_usage(); + return failed_ids; +} + +std::vector +DiskSINDI::Build(const DatasetPtr& base) { + auto failed_ids = this->Add(base); + + auto window_count = + static_cast(align_up(cur_element_count_, window_size_) / window_size_); + term_datacell_->FinalizeTermBuffers(window_count); + this->cal_memory_usage(); + return failed_ids; +} + +bool +DiskSINDI::UpdateVector(int64_t id, const DatasetPtr& new_base, bool force_update) { + throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, + "DiskSINDI does not support UpdateVector"); +} + +DatasetPtr +DiskSINDI::KnnSearch(const DatasetPtr& query, + int64_t k, + const std::string& parameters, + const FilterPtr& filter) const { + return KnnSearch(query, k, parameters, filter, allocator_); +} + +DatasetPtr +DiskSINDI::KnnSearch(const DatasetPtr& query, + int64_t k, + const std::string& parameters, + const FilterPtr& filter, + vsag::Allocator* allocator) const { + std::shared_lock rlock(this->global_mutex_); + + const auto* sparse_vectors = query->GetSparseVectors(); + CHECK_ARGUMENT(query->GetNumElements() == 1, "num of query should be 1"); + CHECK_ARGUMENT(k > 0, "k must be greater than 0"); + auto sparse_query = sparse_vectors[0]; + CHECK_ARGUMENT( + sparse_query.len_ > 0, + fmt::format("query->GetSparseVectors()->len_ ({}) is invalid", sparse_query.len_)); + + DiskSINDISearchParameter search_param; + search_param.FromJson(JsonType::Parse(parameters)); + CHECK_ARGUMENT(search_param.n_candidate <= SPARSE_AMPLIFICATION_FACTOR * k, + fmt::format("n_candidate ({}) should be less than {} * k ({})", + search_param.n_candidate, + SPARSE_AMPLIFICATION_FACTOR, + k)); + InnerSearchParam inner_param; + inner_param.ef = std::max(static_cast(search_param.n_candidate), k); + inner_param.topk = k; + + FilterPtr ft = nullptr; + if (filter != nullptr) { + ft = std::make_shared(filter, *this->label_table_); + } + inner_param.is_inner_id_allowed = ft; + + SparseVector effective_query = sparse_query; + Vector tmp_ids(allocator_); + Vector tmp_vals(allocator_); + if (remap_term_ids_) { + effective_query = remap_sparse_vector_for_query(sparse_query, tmp_ids, tmp_vals); + if (effective_query.len_ == 0) { + auto [results, ret_dists, ret_ids] = create_fast_dataset(0, allocator); + return results; + } + } + + auto computer = std::make_shared(effective_query, search_param, allocator); + const SparseVector* rerank_query = (remap_term_ids_ && use_reorder_) ? &sparse_query : nullptr; + + // Collect query term ids for lazy loading + auto query_term_ids = collect_query_term_ids(computer, allocator); + auto query_term_buffers = term_datacell_->LoadQueryTermBuffers(query_term_ids); + + return search_impl(computer, + inner_param, + allocator, + search_param.use_term_lists_heap_insert, + query_term_buffers, + rerank_query); +} + +template +DatasetPtr +DiskSINDI::search_impl(const SparseTermComputerPtr& computer, + const InnerSearchParam& inner_param, + Allocator* allocator, + bool use_term_lists_heap_insert, + const QueryTermBuffers& query_term_buffers, + const SparseVector* original_query) const { + MaxHeap heap(allocator); + int64_t k = 0; + + if constexpr (mode == KNN_SEARCH) { + k = inner_param.topk; + } + + Vector dists(window_size_, 0.0, allocator); + auto filter = inner_param.is_inner_id_allowed; + const auto [min_window_id, max_window_id] = this->get_min_max_window_id(filter); + + for (auto cur = min_window_id; cur <= max_window_id; cur++) { + auto window_start_id = static_cast(cur) * window_size_; + + term_datacell_->QueryWindow( + dists.data(), static_cast(cur), computer, query_term_buffers); + + if (use_term_lists_heap_insert) { + term_datacell_->InsertHeapByWindowKnn(dists.data(), + static_cast(cur), + computer, + heap, + inner_param, + window_start_id, + inner_param.is_inner_id_allowed != nullptr, + query_term_buffers); + } else { + term_datacell_->InsertHeapByDistsKnn(dists.data(), + dists.size(), + heap, + inner_param, + window_start_id, + inner_param.is_inner_id_allowed != nullptr); + } + } + + // rerank + if (use_reorder_) { + float cur_heap_top = std::numeric_limits::max(); + auto candidate_size = heap.size(); + auto high_precise_heap = std::make_shared>(allocator_, -1); + auto [sorted_ids, sorted_vals] = + sort_sparse_vector(original_query ? *original_query : computer->raw_query_, allocator_); + + // Phase B: collect all candidate inner ids, sort them by inner_id + // (ascending), then issue a single batched IO. Sorting makes disk + // offsets monotonically increasing, which enables + // GetCodesByIdsBatch to merge adjacent IO requests and reduces + // syscall count for DirectIO. + Vector cand_ids(allocator_); + cand_ids.resize(candidate_size); + for (int64_t i = static_cast(candidate_size) - 1; i >= 0; --i) { + cand_ids[i] = heap.top().second; + heap.pop(); + } + std::sort(cand_ids.begin(), cand_ids.end()); + auto batch = rerank_flat_->GetCodesByIdsBatch( + cand_ids.data(), static_cast(candidate_size), allocator_); + + for (uint64_t i = 0; i < candidate_size; i++) { + auto inner_id = cand_ids[i]; + const uint8_t* codes = batch.buffer.data() + batch.in_buffer_offsets[i]; + auto high_precise_distance = + compute_distance_from_codes(codes, sorted_ids, sorted_vals); + auto label = label_table_->GetLabelById(inner_id); + if constexpr (mode == KNN_SEARCH) { + if (high_precise_distance < cur_heap_top or + high_precise_heap->Size() < static_cast(k)) { + high_precise_heap->Push(high_precise_distance, label); + } + if (high_precise_heap->Size() > static_cast(k)) { + high_precise_heap->Pop(); + } + cur_heap_top = high_precise_heap->Top().first; + } + if constexpr (mode == RANGE_SEARCH) { + if (high_precise_distance <= inner_param.radius) { + high_precise_heap->Push(high_precise_distance, label); + } + if (inner_param.range_search_limit_size != -1 and + high_precise_heap->Size() > + static_cast(inner_param.range_search_limit_size)) { + high_precise_heap->Pop(); + } + } + } + + return collect_results(high_precise_heap, allocator_); + } + + // low precision + if constexpr (mode == RANGE_SEARCH) { + k = static_cast(heap.size()); + if (inner_param.range_search_limit_size != -1) { + k = inner_param.range_search_limit_size; + } + } + + int64_t cur_size = std::min(static_cast(heap.size()), k); + + auto [results, ret_dists, ret_ids] = create_fast_dataset(cur_size, allocator_); + if (cur_size == 0) { + return results; + } + + while (heap.size() > k) { + heap.pop(); + } + + for (auto j = cur_size - 1; j >= 0; j--) { + ret_dists[j] = 1 + heap.top().first; + ret_ids[j] = label_table_->GetLabelById(heap.top().second); + heap.pop(); + } + + return results; +} + +DatasetPtr +DiskSINDI::RangeSearch(const DatasetPtr& query, + float radius, + const std::string& parameters, + const FilterPtr& filter, + int64_t limited_size) const { + throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, + "DiskSINDI does not support RangeSearch in stage 2"); +} + +void +DiskSINDI::cal_memory_usage() { + auto memory = sizeof(DiskSINDI); + if (term_datacell_ != nullptr) { + memory += term_datacell_->GetMemoryUsage(); + } + if (this->rerank_flat_ != nullptr) { + memory += this->rerank_flat_->GetMemoryUsage(); + } + memory += sizeof(QuantizationParams); + + std::unique_lock lock(this->memory_usage_mutex_); + this->current_memory_usage_.store(static_cast(memory)); +} + +void +DiskSINDI::Serialize(StreamWriter& writer) const { + std::shared_lock rlock(this->global_mutex_); + + // Pre-compute sizes for two-pass serialization (StreamWriter lacks Seek/WriteAt) + uint64_t metadata_start = writer.GetCursor(); + + uint64_t quantization_size = use_quantization_ ? 3 * sizeof(float) : 0; + uint64_t term_dict_size = (term_id_limit_ + 1) * sizeof(DiskTermEntry); + uint64_t payload_size = term_datacell_->ComputePayloadSize(); + + // Compute segment offsets + DiskSINDIManifest manifest{}; + manifest.term_dict_offset = + metadata_start + sizeof(manifest) + sizeof(cur_element_count_) + quantization_size; + manifest.term_dict_size = term_dict_size; + manifest.posting_payload_offset = manifest.term_dict_offset + term_dict_size; + manifest.posting_payload_size = payload_size; + + uint64_t cursor_after_payload = manifest.posting_payload_offset + payload_size; + + if (use_reorder_) { + manifest.rerank_flat_offset = cursor_after_payload; + manifest.rerank_flat_size = rerank_flat_->CalcSerializeSize(); + cursor_after_payload += manifest.rerank_flat_size; + } + + manifest.label_table_offset = cursor_after_payload; + std::stringstream label_ss; + { + IOStreamWriter label_writer(label_ss); + label_table_->Serialize(label_writer); + } + manifest.label_table_size = label_ss.tellp(); + + // Pass 2: write everything in order + writer.Write(reinterpret_cast(&manifest), sizeof(manifest)); + StreamWriter::WriteObj(writer, cur_element_count_); + if (use_quantization_) { + StreamWriter::WriteObj(writer, quantization_params_->min_val); + StreamWriter::WriteObj(writer, quantization_params_->max_val); + StreamWriter::WriteObj(writer, quantization_params_->diff); + } + + term_datacell_->WriteTermDictAndPayload(writer, manifest.posting_payload_offset); + + if (use_reorder_) { + rerank_flat_->Serialize(writer); + } + + writer.Write(label_ss.str().c_str(), manifest.label_table_size); + + if (remap_term_ids_ && term_id_mapper_) { + term_id_mapper_->Serialize(writer); + } + + // Footer + JsonType jsonify_basic_info; + auto metadata = std::make_shared(); + jsonify_basic_info[INDEX_PARAM].SetString(this->create_param_ptr_->ToString()); + if (use_reorder_) { + jsonify_basic_info[DISKSINDI_RERANK_FLAT_FORMAT_KEY].SetInt( + DISKSINDI_RERANK_FLAT_FORMAT_DATACELL); + } + metadata->Set("basic_info", jsonify_basic_info); + auto footer = std::make_shared