Skip to content

Commit 809877b

Browse files
committed
refactor: extract total_count_ to InnerIndexInterface base class
Move total_count_ from HGraph, BruteForce, and WARP into the InnerIndexInterface base class as std::atomic<uint64_t>. This eliminates code duplication, ensures consistent thread-safe access across all index implementations, and unifies the element count management.

Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
Assisted-by: ClaudeCode:claude-opus-4-6
1 parent 390e430 commit 809877b

6 files changed

Lines changed: 47 additions & 41 deletions

File tree

src/algorithm/bruteforce/bruteforce.cpp

Lines changed: 21 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ BruteForce::Add(const DatasetPtr& data, AddMode mode) {
7373

7474
{
7575
std::lock_guard lock(this->add_mutex_);
76-
if (this->total_count_ == 0) {
76+
if (this->total_count_.load() == 0) {
7777
this->Train(data);
7878
}
7979
}
@@ -88,9 +88,9 @@ BruteForce::Add(const DatasetPtr& data, AddMode mode) {
8888
if (this->label_table_->CheckLabel(label)) {
8989
return label;
9090
}
91-
inner_id = this->total_count_;
92-
this->total_count_++;
93-
this->resize(total_count_);
91+
inner_id = this->total_count_.load();
92+
++this->total_count_;
93+
this->resize(total_count_.load());
9494
this->label_table_->Insert(inner_id, label);
9595
}
9696
std::shared_lock global_lock(this->global_mutex_);
@@ -161,7 +161,7 @@ BruteForce::Remove(const std::vector<int64_t>& ids, RemoveMode mode) {
161161

162162
std::scoped_lock lock(this->add_mutex_, this->label_lookup_mutex_);
163163
for (auto label : ids) {
164-
const auto last_inner_id = static_cast<InnerIdType>(this->total_count_ - 1);
164+
const auto last_inner_id = static_cast<InnerIdType>(this->total_count_.load() - 1);
165165
const auto inner_id = this->label_table_->GetIdByLabel(label);
166166

167167
CHECK_ARGUMENT(inner_id <= last_inner_id, "the element to be remove is invalid");
@@ -181,7 +181,7 @@ BruteForce::Remove(const std::vector<int64_t>& ids, RemoveMode mode) {
181181
this->label_table_->Insert(inner_id, last_label);
182182
}
183183

184-
this->total_count_--;
184+
--this->total_count_;
185185
}
186186
return 1;
187187
}
@@ -247,15 +247,16 @@ BruteForce::SearchWithRequest(const SearchRequest& request) const {
247247
dist_cmp.fetch_add(dist_cmp_local, std::memory_order_relaxed);
248248
};
249249

250+
auto count = total_count_.load();
250251
if (parallel_count == 1 || this->thread_pool_ == nullptr) {
251-
search_func(0, total_count_, heaps[0]);
252+
search_func(0, count, heaps[0]);
252253
heap = heaps[0];
253254
} else {
254255
std::vector<std::future<void>> futures;
255-
auto chunk_size = (total_count_ + parallel_count - 1) / parallel_count;
256+
auto chunk_size = (count + parallel_count - 1) / parallel_count;
256257
for (auto i = 0; i < parallel_count; ++i) {
257258
auto start = i * chunk_size;
258-
auto end = std::min(start + chunk_size, total_count_);
259+
auto end = std::min(start + chunk_size, count);
259260
auto future = this->thread_pool_->GeneralEnqueue(search_func, start, end, heaps[i]);
260261
futures.emplace_back(std::move(future));
261262
}
@@ -289,7 +290,7 @@ BruteForce::RangeSearch(const vsag::DatasetPtr& query,
289290
if (limited_size < 0) {
290291
limited_size = std::numeric_limits<int64_t>::max();
291292
}
292-
if (total_count_ == 0) {
293+
if (total_count_.load() == 0) {
293294
return make_empty_result();
294295
}
295296

@@ -312,16 +313,17 @@ BruteForce::RangeSearch(const vsag::DatasetPtr& query,
312313
};
313314

314315
DistHeapPtr heap = nullptr;
315-
parallel_count = std::min(parallel_count, total_count_);
316+
auto count = total_count_.load();
317+
parallel_count = std::min(parallel_count, count);
316318
if (parallel_count <= 1 or this->thread_pool_ == nullptr) {
317-
heap = search_func(0, total_count_);
319+
heap = search_func(0, count);
318320
} else {
319321
std::vector<std::future<DistHeapPtr>> futures;
320322
futures.reserve(parallel_count);
321-
auto chunk_size = (total_count_ + parallel_count - 1) / parallel_count;
323+
auto chunk_size = (count + parallel_count - 1) / parallel_count;
322324
for (uint64_t i = 0; i < parallel_count; ++i) {
323325
auto start = static_cast<InnerIdType>(i * chunk_size);
324-
auto end = static_cast<InnerIdType>(std::min(start + chunk_size, total_count_));
326+
auto end = static_cast<InnerIdType>(std::min(start + chunk_size, count));
325327
futures.emplace_back(this->thread_pool_->GeneralEnqueue(search_func, start, end));
326328
}
327329

@@ -368,7 +370,7 @@ BruteForce::Serialize(StreamWriter& writer) const {
368370
// serialize footer (introduced since v0.15)
369371
JsonType basic_info;
370372
basic_info["dim"].SetInt(dim_);
371-
basic_info["total_count"].SetInt(total_count_);
373+
basic_info["total_count"].SetInt(total_count_.load());
372374
basic_info[INDEX_PARAM].SetString(this->create_param_ptr_->ToString());
373375
write_index_footer(writer, basic_info);
374376
}
@@ -386,7 +388,9 @@ BruteForce::Deserialize(StreamReader& reader) {
386388
logger::debug("parse with v0.13 version format");
387389

388390
StreamReader::ReadObj(buffer_reader, dim_);
389-
StreamReader::ReadObj(buffer_reader, total_count_);
391+
uint64_t count = 0;
392+
StreamReader::ReadObj(buffer_reader, count);
393+
total_count_.store(count);
390394
} else { // create like `else if ( ver in [v0.15, v0.17] )` here if need in the future
391395
logger::debug("parse with new version format");
392396

@@ -404,7 +408,7 @@ BruteForce::Deserialize(StreamReader& reader) {
404408
}
405409
}
406410
dim_ = basic_info["dim"].GetInt();
407-
total_count_ = basic_info["total_count"].GetInt();
411+
total_count_.store(basic_info["total_count"].GetInt());
408412

409413
if (this->use_attribute_filter_ and this->attr_filter_index_ != nullptr) {
410414
this->attr_filter_index_->Deserialize(buffer_reader);

src/algorithm/bruteforce/bruteforce.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ class BruteForce : public InnerIndexInterface {
8080

8181
[[nodiscard]] int64_t
8282
GetNumElements() const override {
83-
return this->total_count_ - this->delete_count_;
83+
return static_cast<int64_t>(this->total_count_.load()) -
84+
static_cast<int64_t>(this->delete_count_);
8485
}
8586

8687
[[nodiscard]] int64_t
@@ -143,8 +144,6 @@ class BruteForce : public InnerIndexInterface {
143144
private:
144145
FlattenInterfacePtr inner_codes_{nullptr};
145146

146-
uint64_t total_count_{0};
147-
148147
uint64_t delete_count_{0};
149148

150149
uint64_t resize_increase_count_bit_{DEFAULT_RESIZE_BIT};

src/algorithm/hgraph/hgraph.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,6 @@ class HGraph : public InnerIndexInterface {
510510
uint64_t ef_construct_{400};
511511
float alpha_{1.0};
512512

513-
std::atomic<uint64_t> total_count_{0};
514-
515513
std::shared_ptr<VisitedListPool> pool_{nullptr};
516514

517515
mutable std::shared_mutex global_mutex_;

src/algorithm/inner_index_interface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ class InnerIndexInterface {
570570
bool immutable_{false};
571571

572572
protected:
573+
std::atomic<uint64_t> total_count_{0};
574+
573575
std::atomic<int64_t> current_memory_usage_{0};
574576
mutable std::shared_mutex memory_usage_mutex_{};
575577

src/algorithm/warp/warp.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ WARP::Add(const DatasetPtr& data, AddMode mode) {
110110

111111
{
112112
std::lock_guard lock(this->add_mutex_);
113-
if (this->total_count_ == 0) {
113+
if (this->total_count_.load() == 0) {
114114
this->Train(data);
115115
}
116116
}
@@ -150,9 +150,9 @@ WARP::Add(const DatasetPtr& data, AddMode mode) {
150150
if (this->label_table_->CheckLabel(label)) {
151151
return label;
152152
}
153-
inner_id = this->total_count_;
154-
this->total_count_++;
155-
this->resize(total_count_);
153+
inner_id = this->total_count_.load();
154+
++this->total_count_;
155+
this->resize(total_count_.load());
156156
this->label_table_->Insert(inner_id, label);
157157
}
158158
std::shared_lock global_lock(this->global_mutex_);
@@ -306,14 +306,14 @@ WARP::SearchWithRequest(const SearchRequest& request) const {
306306
DistHeapPtr heap = nullptr;
307307

308308
if (parallel_count == 1 || this->thread_pool_ == nullptr ||
309-
total_count_ < MIN_PARALLEL_SEARCH_DOC_COUNT) {
310-
heap = search_func(0, total_count_);
309+
total_count_.load() < MIN_PARALLEL_SEARCH_DOC_COUNT) {
310+
heap = search_func(0, total_count_.load());
311311
} else {
312312
std::vector<std::future<DistHeapPtr>> futures;
313-
auto chunk_size = (total_count_ + parallel_count - 1) / parallel_count;
313+
auto chunk_size = (total_count_.load() + parallel_count - 1) / parallel_count;
314314
for (auto i = 0; i < static_cast<int>(parallel_count); ++i) {
315315
auto start = i * chunk_size;
316-
auto end = std::min(start + chunk_size, static_cast<uint64_t>(total_count_));
316+
auto end = std::min(start + chunk_size, total_count_.load());
317317
if (start < end) {
318318
auto future = this->thread_pool_->GeneralEnqueue(search_func, start, end);
319319
futures.emplace_back(std::move(future));
@@ -376,11 +376,11 @@ WARP::RangeSearch(const vsag::DatasetPtr& query,
376376

377377
// Use serial version if no thread pool or small dataset
378378
if (parallel_count == 1 || this->thread_pool_ == nullptr ||
379-
total_count_ < MIN_PARALLEL_SEARCH_DOC_COUNT) {
379+
total_count_.load() < MIN_PARALLEL_SEARCH_DOC_COUNT) {
380380
DistHeapPtr heap =
381381
std::make_shared<StandardHeap<true, true>>(this->allocator_, limited_size);
382382

383-
for (InnerIdType doc_id = 0; doc_id < total_count_; ++doc_id) {
383+
for (InnerIdType doc_id = 0; doc_id < total_count_.load(); ++doc_id) {
384384
if (filter != nullptr and
385385
not filter->CheckValid(this->label_table_->GetLabelById(doc_id))) {
386386
continue;
@@ -410,12 +410,13 @@ WARP::RangeSearch(const vsag::DatasetPtr& query,
410410
auto future = this->thread_pool_->GeneralEnqueue([&]() {
411411
std::vector<std::pair<float, InnerIdType>> local_results;
412412
// Pre-allocate to avoid frequent reallocations
413-
local_results.reserve(std::min(static_cast<size_t>(1024),
414-
static_cast<size_t>(total_count_) / parallel_count));
413+
local_results.reserve(
414+
std::min(static_cast<size_t>(1024),
415+
static_cast<size_t>(total_count_.load()) / parallel_count));
415416

416417
while (true) {
417418
auto doc_id = next_doc.fetch_add(1);
418-
if (doc_id >= total_count_) {
419+
if (doc_id >= total_count_.load()) {
419420
break;
420421
}
421422

@@ -474,7 +475,8 @@ WARP::RangeSearch(const vsag::DatasetPtr& query,
474475
void
475476
WARP::Serialize(StreamWriter& writer) const {
476477
// Serialize document offsets (size = total_count_ + 1)
477-
StreamWriter::WriteObj(writer, total_count_);
478+
uint64_t count = total_count_.load();
479+
StreamWriter::WriteObj(writer, count);
478480
StreamWriter::WriteObj(writer, total_vector_count_);
479481

480482
// Batch write the doc_offsets array using WriteVector
@@ -486,7 +488,7 @@ WARP::Serialize(StreamWriter& writer) const {
486488
// Serialize footer
487489
JsonType basic_info;
488490
basic_info["dim"].SetInt(dim_);
489-
basic_info["total_count"].SetInt(total_count_);
491+
basic_info["total_count"].SetInt(total_count_.load());
490492
basic_info["total_vector_count"].SetInt(total_vector_count_);
491493
basic_info[INDEX_PARAM].SetString(this->create_param_ptr_->ToString());
492494
write_index_footer(writer, basic_info);
@@ -516,7 +518,9 @@ WARP::Deserialize(StreamReader& reader) {
516518
}
517519
dim_ = basic_info["dim"].GetInt();
518520

519-
StreamReader::ReadObj(buffer_reader, total_count_);
521+
uint64_t count = 0;
522+
StreamReader::ReadObj(buffer_reader, count);
523+
total_count_.store(count);
520524
StreamReader::ReadObj(buffer_reader, total_vector_count_);
521525

522526
// Batch read the doc_offsets array using ReadVector

src/algorithm/warp/warp.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ class WARP : public InnerIndexInterface {
7373

7474
[[nodiscard]] int64_t
7575
GetNumElements() const override {
76-
return this->total_count_ - this->delete_count_;
76+
return static_cast<int64_t>(this->total_count_.load()) -
77+
static_cast<int64_t>(this->delete_count_);
7778
}
7879

7980
[[nodiscard]] int64_t
@@ -132,8 +133,6 @@ class WARP : public InnerIndexInterface {
132133
private:
133134
FlattenInterfacePtr inner_codes_{nullptr};
134135

135-
uint64_t total_count_{0}; // Number of documents (not vectors)
136-
137136
uint64_t delete_count_{0};
138137

139138
uint64_t total_vector_count_{0}; // Total number of vectors across all docs

0 commit comments

Comments
 (0)