Skip to content

Commit 34eb9f8

Browse files
authored
refactor(algorithm): extract common search helpers to InnerIndexInterface base class (#2139)
refactor(index): extract common search helpers into InnerIndexInterface Extract repetitive validation, filter composition, result packing, and serialization logic into protected helper methods, eliminating ~254 lines of duplicated boilerplate across 6 index implementations. New helpers in InnerIndexInterface: - validate_knn_args / validate_range_args / validate_search_query - create_search_filter (supports InnerIdWrapperFilter & ExtraInfoWrapperFilter) - pack_knn_result / pack_knn_result_with_extra_info / make_empty_result - write_index_footer / read_index_footer Applied to: BruteForce, HGraph, IVF, Pyramid, SINDI, and WARP. Assisted-by: OpenCode:claude-opus-4.7 Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent 674eb3b commit 34eb9f8

8 files changed

Lines changed: 236 additions & 255 deletions

File tree

src/algorithm/bruteforce/bruteforce.cpp

Lines changed: 11 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -210,16 +210,7 @@ BruteForce::SearchWithRequest(const SearchRequest& request) const {
210210
ExecutorPtr executor = nullptr;
211211
Filter* attr_filter = nullptr;
212212

213-
auto combined_filter = std::make_shared<CombinedFilter>();
214-
combined_filter->AppendFilter(this->label_table_->GetDeletedIdsFilter());
215-
if (request.filter_ != nullptr) {
216-
combined_filter->AppendFilter(
217-
std::make_shared<InnerIdWrapperFilter>(request.filter_, *this->label_table_));
218-
}
219-
FilterPtr ft = nullptr;
220-
if (not combined_filter->IsEmpty()) {
221-
ft = combined_filter;
222-
}
213+
FilterPtr ft = this->create_search_filter(request.filter_);
223214

224215
if (request.enable_attribute_filter_) {
225216
auto& schema = this->attr_filter_index_->field_type_map_;
@@ -277,19 +268,13 @@ BruteForce::SearchWithRequest(const SearchRequest& request) const {
277268
}
278269
}
279270

280-
auto [dataset_results, dists, ids] =
281-
create_fast_dataset(static_cast<int64_t>(heap->Size()), allocator_);
282-
for (auto j = static_cast<int64_t>(heap->Size() - 1); j >= 0; --j) {
283-
dists[j] = heap->Top().first;
284-
ids[j] = this->label_table_->GetLabelById(heap->Top().second);
285-
heap->Pop();
286-
}
271+
auto result = this->pack_knn_result(heap);
287272

288273
JsonType stats;
289274
stats["dist_cmp"].SetInt(dist_cmp.load(std::memory_order_relaxed));
290-
dataset_results->Statistics(stats.Dump());
275+
result->Statistics(stats.Dump());
291276

292-
return std::move(dataset_results);
277+
return result;
293278
}
294279

295280
DatasetPtr
@@ -300,14 +285,12 @@ BruteForce::RangeSearch(const vsag::DatasetPtr& query,
300285
int64_t limited_size) const {
301286
std::shared_lock read_lock(this->global_mutex_);
302287
auto computer = this->inner_codes_->FactoryComputer(query->GetFloat32Vectors());
303-
CHECK_ARGUMENT(limited_size != 0,
304-
fmt::format("limited_size({}) must not be equal to 0", limited_size));
288+
this->validate_range_args(query, radius, limited_size);
305289
if (limited_size < 0) {
306290
limited_size = std::numeric_limits<int64_t>::max();
307291
}
308292
if (total_count_ == 0) {
309-
auto [dataset_results, dists, ids] = create_fast_dataset(0, allocator_);
310-
return std::move(dataset_results);
293+
return make_empty_result();
311294
}
312295

313296
auto brute_force_params = BruteForceSearchParameters::FromJson(parameters);
@@ -352,14 +335,7 @@ BruteForce::RangeSearch(const vsag::DatasetPtr& query,
352335
}
353336
}
354337

355-
auto [dataset_results, dists, ids] =
356-
create_fast_dataset(static_cast<int64_t>(heap->Size()), allocator_);
357-
for (auto j = static_cast<int64_t>(heap->Size() - 1); j >= 0; --j) {
358-
dists[j] = heap->Top().first;
359-
ids[j] = this->label_table_->GetLabelById(heap->Top().second);
360-
heap->Pop();
361-
}
362-
return std::move(dataset_results);
338+
return this->pack_knn_result(heap);
363339
}
364340

365341
float
@@ -390,34 +366,30 @@ BruteForce::Serialize(StreamWriter& writer) const {
390366
this->label_table_->Serialize(writer);
391367

392368
// serialize footer (introduced since v0.15)
393-
auto metadata = std::make_shared<Metadata>();
394369
JsonType basic_info;
395370
basic_info["dim"].SetInt(dim_);
396371
basic_info["total_count"].SetInt(total_count_);
397372
basic_info[INDEX_PARAM].SetString(this->create_param_ptr_->ToString());
398-
metadata->Set("basic_info", basic_info);
399-
auto footer = std::make_shared<Footer>(metadata);
400-
footer->Write(writer);
373+
write_index_footer(writer, basic_info);
401374
}
402375

403376
void
404377
BruteForce::Deserialize(StreamReader& reader) {
405378
// try to deserialize footer (only in new version)
406-
auto footer = Footer::Parse(reader);
379+
JsonType basic_info;
380+
bool has_footer = read_index_footer(reader, basic_info);
407381

408382
BufferStreamReader buffer_reader(
409383
&reader, std::numeric_limits<uint64_t>::max(), this->allocator_);
410384

411-
if (footer == nullptr) { // old format, DON'T EDIT, remove in the future
385+
if (not has_footer) { // old format, DON'T EDIT, remove in the future
412386
logger::debug("parse with v0.13 version format");
413387

414388
StreamReader::ReadObj(buffer_reader, dim_);
415389
StreamReader::ReadObj(buffer_reader, total_count_);
416390
} else { // create like `else if ( ver in [v0.15, v0.17] )` here if need in the future
417391
logger::debug("parse with new version format");
418392

419-
auto metadata = footer->GetMetadata();
420-
auto basic_info = metadata->Get("basic_info");
421393
if (basic_info.Contains(INDEX_PARAM)) {
422394
std::string index_param_string = basic_info[INDEX_PARAM].GetString();
423395
auto index_param = std::make_shared<BruteForceParameter>();

src/algorithm/hgraph/hgraph_search.cpp

Lines changed: 9 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,7 @@ HGraph::KnnSearch(const DatasetPtr& query,
7373
if (GetNumElements() == 0) {
7474
return DatasetImpl::MakeEmptyDataset();
7575
}
76-
int64_t query_dim = query->GetDim();
77-
if (data_type_ != DataTypes::DATA_TYPE_SPARSE) {
78-
CHECK_ARGUMENT(
79-
query_dim == dim_,
80-
fmt::format("query.dim({}) must be equal to index.dim({})", query_dim, dim_));
81-
}
76+
this->validate_knn_args(query, k);
8277

8378
auto params = HGraphSearchParameters::FromJson(parameters);
8479
auto ef_search_threshold = std::max<int64_t>(AMPLIFICATION_FACTOR * k, 1000);
@@ -91,28 +86,9 @@ HGraph::KnnSearch(const DatasetPtr& query,
9186
force_remove_rlock = std::shared_lock<std::shared_mutex>(this->force_remove_mutex_);
9287
}
9388
std::shared_lock shared_lock(this->global_mutex_);
94-
// check k
95-
CHECK_ARGUMENT(k > 0, fmt::format("k({}) must be greater than 0", k));
9689
k = std::min(k, GetNumElements());
9790

98-
// check query vector
99-
CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only");
100-
101-
auto combined_filter = std::make_shared<CombinedFilter>();
102-
combined_filter->AppendFilter(this->label_table_->GetDeletedIdsFilter());
103-
if (filter != nullptr) {
104-
if (params.use_extra_info_filter) {
105-
combined_filter->AppendFilter(
106-
std::make_shared<ExtraInfoWrapperFilter>(filter, this->extra_infos_));
107-
} else {
108-
combined_filter->AppendFilter(
109-
std::make_shared<InnerIdWrapperFilter>(filter, *this->label_table_));
110-
}
111-
}
112-
FilterPtr ft = nullptr;
113-
if (not combined_filter->IsEmpty()) {
114-
ft = combined_filter;
115-
}
91+
FilterPtr ft = this->create_search_filter(filter, params.use_extra_info_filter);
11692

11793
if (iter_ctx == nullptr) {
11894
auto cur_count = this->total_count_.load();
@@ -376,32 +352,9 @@ HGraph::RangeSearch(const DatasetPtr& query,
376352
SearchStatistics stats;
377353
QueryContext ctx{.stats = &stats};
378354

379-
auto combined_filter = std::make_shared<CombinedFilter>();
380-
combined_filter->AppendFilter(this->label_table_->GetDeletedIdsFilter());
381-
if (filter != nullptr) {
382-
combined_filter->AppendFilter(
383-
std::make_shared<InnerIdWrapperFilter>(filter, *this->label_table_));
384-
}
385-
FilterPtr ft = nullptr;
386-
if (not combined_filter->IsEmpty()) {
387-
ft = combined_filter;
388-
}
389-
390-
int64_t query_dim = query->GetDim();
391-
if (data_type_ != DataTypes::DATA_TYPE_SPARSE) {
392-
CHECK_ARGUMENT(
393-
query_dim == dim_,
394-
fmt::format("query.dim({}) must be equal to index.dim({})", query_dim, dim_));
395-
}
396-
// check radius
397-
CHECK_ARGUMENT(radius >= 0, fmt::format("radius({}) must be greater equal than 0", radius))
398-
399-
// check query vector
400-
CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only");
355+
FilterPtr ft = this->create_search_filter(filter);
401356

402-
// check limited_size
403-
CHECK_ARGUMENT(limited_size != 0,
404-
fmt::format("limited_size({}) must not be equal to 0", limited_size));
357+
this->validate_range_args(query, radius, limited_size);
405358

406359
std::shared_lock<std::shared_mutex> force_remove_rlock;
407360
if (this->support_force_remove()) {
@@ -477,26 +430,9 @@ HGraph::RangeSearch(const DatasetPtr& query,
477430
}
478431
}
479432

480-
auto count = static_cast<const int64_t>(search_result->Size());
481-
auto [dataset_results, dists, ids] = create_fast_dataset(count, allocator_);
482-
char* extra_infos = nullptr;
483-
if (extra_info_size_ > 0) {
484-
extra_infos =
485-
static_cast<char*>(allocator_->Allocate(extra_info_size_ * search_result->Size()));
486-
dataset_results->ExtraInfos(extra_infos);
487-
}
488-
for (int64_t j = count - 1; j >= 0; --j) {
489-
dists[j] = search_result->Top().first;
490-
ids[j] = this->label_table_->GetLabelById(search_result->Top().second);
491-
if (extra_infos != nullptr) {
492-
this->extra_infos_->GetExtraInfoById(search_result->Top().second,
493-
extra_infos + extra_info_size_ * j);
494-
}
495-
search_result->Pop();
496-
}
497-
498-
dataset_results->Statistics(stats.Dump());
499-
return std::move(dataset_results);
433+
auto result = this->pack_knn_result_with_extra_info(search_result, allocator_);
434+
result->Statistics(stats.Dump());
435+
return result;
500436
}
501437

502438
[[nodiscard]] DatasetPtr
@@ -508,13 +444,8 @@ HGraph::SearchWithRequest(const SearchRequest& request) const {
508444
}
509445

510446
const auto& query = request.query_;
511-
int64_t query_dim = query->GetDim();
512447
auto k = request.topk_;
513-
if (data_type_ != DataTypes::DATA_TYPE_SPARSE) {
514-
CHECK_ARGUMENT(
515-
query_dim == dim_,
516-
fmt::format("query.dim({}) must be equal to index.dim({})", query_dim, dim_));
517-
}
448+
this->validate_knn_args(query, k);
518449

519450
auto params = HGraphSearchParameters::FromJson(request.params_str_);
520451

@@ -528,14 +459,8 @@ HGraph::SearchWithRequest(const SearchRequest& request) const {
528459
force_remove_rlock = std::shared_lock<std::shared_mutex>(this->force_remove_mutex_);
529460
}
530461
std::shared_lock shared_lock(this->global_mutex_);
531-
532-
// check k
533-
CHECK_ARGUMENT(k > 0, fmt::format("k({}) must be greater than 0", k));
534462
k = std::min(k, GetNumElements());
535463

536-
// check query vector
537-
CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only");
538-
539464
// Setup reasoning context if expected labels are provided.
540465
std::shared_ptr<ReasoningContext> reasoning_ctx;
541466
if (not request.expected_labels_.empty()) {
@@ -591,21 +516,7 @@ HGraph::SearchWithRequest(const SearchRequest& request) const {
591516
search_param.ep = result->Top().second;
592517
}
593518

594-
auto combined_filter = std::make_shared<CombinedFilter>();
595-
combined_filter->AppendFilter(this->label_table_->GetDeletedIdsFilter());
596-
if (request.filter_ != nullptr) {
597-
if (params.use_extra_info_filter) {
598-
combined_filter->AppendFilter(
599-
std::make_shared<ExtraInfoWrapperFilter>(request.filter_, this->extra_infos_));
600-
} else {
601-
combined_filter->AppendFilter(
602-
std::make_shared<InnerIdWrapperFilter>(request.filter_, *this->label_table_));
603-
}
604-
}
605-
FilterPtr ft = nullptr;
606-
if (not combined_filter->IsEmpty()) {
607-
ft = combined_filter;
608-
}
519+
FilterPtr ft = this->create_search_filter(request.filter_, params.use_extra_info_filter);
609520

610521
if (request.enable_attribute_filter_ and this->attr_filter_index_ != nullptr) {
611522
auto& schema = this->attr_filter_index_->field_type_map_;

0 commit comments

Comments
 (0)