@@ -73,12 +73,7 @@ HGraph::KnnSearch(const DatasetPtr& query,
7373 if (GetNumElements () == 0 ) {
7474 return DatasetImpl::MakeEmptyDataset ();
7575 }
76- int64_t query_dim = query->GetDim ();
77- if (data_type_ != DataTypes::DATA_TYPE_SPARSE) {
78- CHECK_ARGUMENT (
79- query_dim == dim_,
80- fmt::format (" query.dim({}) must be equal to index.dim({})" , query_dim, dim_));
81- }
76+ this ->validate_knn_args (query, k);
8277
8378 auto params = HGraphSearchParameters::FromJson (parameters);
8479 auto ef_search_threshold = std::max<int64_t >(AMPLIFICATION_FACTOR * k, 1000 );
@@ -91,28 +86,9 @@ HGraph::KnnSearch(const DatasetPtr& query,
9186 force_remove_rlock = std::shared_lock<std::shared_mutex>(this ->force_remove_mutex_ );
9287 }
9388 std::shared_lock shared_lock (this ->global_mutex_ );
94- // check k
95- CHECK_ARGUMENT (k > 0 , fmt::format (" k({}) must be greater than 0" , k));
9689 k = std::min (k, GetNumElements ());
9790
98- // check query vector
99- CHECK_ARGUMENT (query->GetNumElements () == 1 , " query dataset should contain 1 vector only" );
100-
101- auto combined_filter = std::make_shared<CombinedFilter>();
102- combined_filter->AppendFilter (this ->label_table_ ->GetDeletedIdsFilter ());
103- if (filter != nullptr ) {
104- if (params.use_extra_info_filter ) {
105- combined_filter->AppendFilter (
106- std::make_shared<ExtraInfoWrapperFilter>(filter, this ->extra_infos_ ));
107- } else {
108- combined_filter->AppendFilter (
109- std::make_shared<InnerIdWrapperFilter>(filter, *this ->label_table_ ));
110- }
111- }
112- FilterPtr ft = nullptr ;
113- if (not combined_filter->IsEmpty ()) {
114- ft = combined_filter;
115- }
91+ FilterPtr ft = this ->create_search_filter (filter, params.use_extra_info_filter );
11692
11793 if (iter_ctx == nullptr ) {
11894 auto cur_count = this ->total_count_ .load ();
@@ -376,32 +352,9 @@ HGraph::RangeSearch(const DatasetPtr& query,
376352 SearchStatistics stats;
377353 QueryContext ctx{.stats = &stats};
378354
379- auto combined_filter = std::make_shared<CombinedFilter>();
380- combined_filter->AppendFilter (this ->label_table_ ->GetDeletedIdsFilter ());
381- if (filter != nullptr ) {
382- combined_filter->AppendFilter (
383- std::make_shared<InnerIdWrapperFilter>(filter, *this ->label_table_ ));
384- }
385- FilterPtr ft = nullptr ;
386- if (not combined_filter->IsEmpty ()) {
387- ft = combined_filter;
388- }
389-
390- int64_t query_dim = query->GetDim ();
391- if (data_type_ != DataTypes::DATA_TYPE_SPARSE) {
392- CHECK_ARGUMENT (
393- query_dim == dim_,
394- fmt::format (" query.dim({}) must be equal to index.dim({})" , query_dim, dim_));
395- }
396- // check radius
397- CHECK_ARGUMENT (radius >= 0 , fmt::format (" radius({}) must be greater equal than 0" , radius))
398-
399- // check query vector
400- CHECK_ARGUMENT (query->GetNumElements () == 1 , " query dataset should contain 1 vector only" );
355+ FilterPtr ft = this ->create_search_filter (filter);
401356
402- // check limited_size
403- CHECK_ARGUMENT (limited_size != 0 ,
404- fmt::format (" limited_size({}) must not be equal to 0" , limited_size));
357+ this ->validate_range_args (query, radius, limited_size);
405358
406359 std::shared_lock<std::shared_mutex> force_remove_rlock;
407360 if (this ->support_force_remove ()) {
@@ -477,26 +430,9 @@ HGraph::RangeSearch(const DatasetPtr& query,
477430 }
478431 }
479432
480- auto count = static_cast <const int64_t >(search_result->Size ());
481- auto [dataset_results, dists, ids] = create_fast_dataset (count, allocator_);
482- char * extra_infos = nullptr ;
483- if (extra_info_size_ > 0 ) {
484- extra_infos =
485- static_cast <char *>(allocator_->Allocate (extra_info_size_ * search_result->Size ()));
486- dataset_results->ExtraInfos (extra_infos);
487- }
488- for (int64_t j = count - 1 ; j >= 0 ; --j) {
489- dists[j] = search_result->Top ().first ;
490- ids[j] = this ->label_table_ ->GetLabelById (search_result->Top ().second );
491- if (extra_infos != nullptr ) {
492- this ->extra_infos_ ->GetExtraInfoById (search_result->Top ().second ,
493- extra_infos + extra_info_size_ * j);
494- }
495- search_result->Pop ();
496- }
497-
498- dataset_results->Statistics (stats.Dump ());
499- return std::move (dataset_results);
433+ auto result = this ->pack_knn_result_with_extra_info (search_result, allocator_);
434+ result->Statistics (stats.Dump ());
435+ return result;
500436}
501437
502438[[nodiscard]] DatasetPtr
@@ -508,13 +444,8 @@ HGraph::SearchWithRequest(const SearchRequest& request) const {
508444 }
509445
510446 const auto & query = request.query_ ;
511- int64_t query_dim = query->GetDim ();
512447 auto k = request.topk_ ;
513- if (data_type_ != DataTypes::DATA_TYPE_SPARSE) {
514- CHECK_ARGUMENT (
515- query_dim == dim_,
516- fmt::format (" query.dim({}) must be equal to index.dim({})" , query_dim, dim_));
517- }
448+ this ->validate_knn_args (query, k);
518449
519450 auto params = HGraphSearchParameters::FromJson (request.params_str_ );
520451
@@ -528,14 +459,8 @@ HGraph::SearchWithRequest(const SearchRequest& request) const {
528459 force_remove_rlock = std::shared_lock<std::shared_mutex>(this ->force_remove_mutex_ );
529460 }
530461 std::shared_lock shared_lock (this ->global_mutex_ );
531-
532- // check k
533- CHECK_ARGUMENT (k > 0 , fmt::format (" k({}) must be greater than 0" , k));
534462 k = std::min (k, GetNumElements ());
535463
536- // check query vector
537- CHECK_ARGUMENT (query->GetNumElements () == 1 , " query dataset should contain 1 vector only" );
538-
539464 // Setup reasoning context if expected labels are provided.
540465 std::shared_ptr<ReasoningContext> reasoning_ctx;
541466 if (not request.expected_labels_ .empty ()) {
@@ -591,21 +516,7 @@ HGraph::SearchWithRequest(const SearchRequest& request) const {
591516 search_param.ep = result->Top ().second ;
592517 }
593518
594- auto combined_filter = std::make_shared<CombinedFilter>();
595- combined_filter->AppendFilter (this ->label_table_ ->GetDeletedIdsFilter ());
596- if (request.filter_ != nullptr ) {
597- if (params.use_extra_info_filter ) {
598- combined_filter->AppendFilter (
599- std::make_shared<ExtraInfoWrapperFilter>(request.filter_ , this ->extra_infos_ ));
600- } else {
601- combined_filter->AppendFilter (
602- std::make_shared<InnerIdWrapperFilter>(request.filter_ , *this ->label_table_ ));
603- }
604- }
605- FilterPtr ft = nullptr ;
606- if (not combined_filter->IsEmpty ()) {
607- ft = combined_filter;
608- }
519+ FilterPtr ft = this ->create_search_filter (request.filter_ , params.use_extra_info_filter );
609520
610521 if (request.enable_attribute_filter_ and this ->attr_filter_index_ != nullptr ) {
611522 auto & schema = this ->attr_filter_index_ ->field_type_map_ ;
0 commit comments