@@ -73,7 +73,7 @@ BruteForce::Add(const DatasetPtr& data, AddMode mode) {
7373
7474 {
7575 std::lock_guard lock (this ->add_mutex_ );
76- if (this ->total_count_ == 0 ) {
76+ if (this ->total_count_ . load () == 0 ) {
7777 this ->Train (data);
7878 }
7979 }
@@ -88,9 +88,9 @@ BruteForce::Add(const DatasetPtr& data, AddMode mode) {
8888 if (this ->label_table_ ->CheckLabel (label)) {
8989 return label;
9090 }
91- inner_id = this ->total_count_ ;
92- this ->total_count_ ++ ;
93- this ->resize (total_count_);
91+ inner_id = this ->total_count_ . load () ;
92+ ++ this ->total_count_ ;
93+ this ->resize (total_count_. load () );
9494 this ->label_table_ ->Insert (inner_id, label);
9595 }
9696 std::shared_lock global_lock (this ->global_mutex_ );
@@ -161,7 +161,7 @@ BruteForce::Remove(const std::vector<int64_t>& ids, RemoveMode mode) {
161161
162162 std::scoped_lock lock (this ->add_mutex_ , this ->label_lookup_mutex_ );
163163 for (auto label : ids) {
164- const auto last_inner_id = static_cast <InnerIdType>(this ->total_count_ - 1 );
164+ const auto last_inner_id = static_cast <InnerIdType>(this ->total_count_ . load () - 1 );
165165 const auto inner_id = this ->label_table_ ->GetIdByLabel (label);
166166
167167 CHECK_ARGUMENT (inner_id <= last_inner_id, " the element to be remove is invalid" );
@@ -181,7 +181,7 @@ BruteForce::Remove(const std::vector<int64_t>& ids, RemoveMode mode) {
181181 this ->label_table_ ->Insert (inner_id, last_label);
182182 }
183183
184- this ->total_count_ -- ;
184+ -- this ->total_count_ ;
185185 }
186186 return 1 ;
187187}
@@ -247,15 +247,16 @@ BruteForce::SearchWithRequest(const SearchRequest& request) const {
247247 dist_cmp.fetch_add (dist_cmp_local, std::memory_order_relaxed);
248248 };
249249
250+ auto count = total_count_.load ();
250251 if (parallel_count == 1 || this ->thread_pool_ == nullptr ) {
251- search_func (0 , total_count_ , heaps[0 ]);
252+ search_func (0 , count , heaps[0 ]);
252253 heap = heaps[0 ];
253254 } else {
254255 std::vector<std::future<void >> futures;
255- auto chunk_size = (total_count_ + parallel_count - 1 ) / parallel_count;
256+ auto chunk_size = (count + parallel_count - 1 ) / parallel_count;
256257 for (auto i = 0 ; i < parallel_count; ++i) {
257258 auto start = i * chunk_size;
258- auto end = std::min (start + chunk_size, total_count_ );
259+ auto end = std::min (start + chunk_size, count );
259260 auto future = this ->thread_pool_ ->GeneralEnqueue (search_func, start, end, heaps[i]);
260261 futures.emplace_back (std::move (future));
261262 }
@@ -289,7 +290,7 @@ BruteForce::RangeSearch(const vsag::DatasetPtr& query,
289290 if (limited_size < 0 ) {
290291 limited_size = std::numeric_limits<int64_t >::max ();
291292 }
292- if (total_count_ == 0 ) {
293+ if (total_count_. load () == 0 ) {
293294 return make_empty_result ();
294295 }
295296
@@ -312,16 +313,17 @@ BruteForce::RangeSearch(const vsag::DatasetPtr& query,
312313 };
313314
314315 DistHeapPtr heap = nullptr ;
315- parallel_count = std::min (parallel_count, total_count_);
316+ auto count = total_count_.load ();
317+ parallel_count = std::min (parallel_count, count);
316318 if (parallel_count <= 1 or this ->thread_pool_ == nullptr ) {
317- heap = search_func (0 , total_count_ );
319+ heap = search_func (0 , count );
318320 } else {
319321 std::vector<std::future<DistHeapPtr>> futures;
320322 futures.reserve (parallel_count);
321- auto chunk_size = (total_count_ + parallel_count - 1 ) / parallel_count;
323+ auto chunk_size = (count + parallel_count - 1 ) / parallel_count;
322324 for (uint64_t i = 0 ; i < parallel_count; ++i) {
323325 auto start = static_cast <InnerIdType>(i * chunk_size);
324- auto end = static_cast <InnerIdType>(std::min (start + chunk_size, total_count_ ));
326+ auto end = static_cast <InnerIdType>(std::min (start + chunk_size, count ));
325327 futures.emplace_back (this ->thread_pool_ ->GeneralEnqueue (search_func, start, end));
326328 }
327329
@@ -368,7 +370,7 @@ BruteForce::Serialize(StreamWriter& writer) const {
368370 // serialize footer (introduced since v0.15)
369371 JsonType basic_info;
370372 basic_info[" dim" ].SetInt (dim_);
371- basic_info[" total_count" ].SetInt (total_count_);
373+ basic_info[" total_count" ].SetInt (total_count_. load () );
372374 basic_info[INDEX_PARAM].SetString (this ->create_param_ptr_ ->ToString ());
373375 write_index_footer (writer, basic_info);
374376}
@@ -386,7 +388,9 @@ BruteForce::Deserialize(StreamReader& reader) {
386388 logger::debug (" parse with v0.13 version format" );
387389
388390 StreamReader::ReadObj (buffer_reader, dim_);
389- StreamReader::ReadObj (buffer_reader, total_count_);
391+ uint64_t count = 0 ;
392+ StreamReader::ReadObj (buffer_reader, count);
393+ total_count_.store (count);
390394 } else { // create like `else if ( ver in [v0.15, v0.17] )` here if need in the future
391395 logger::debug (" parse with new version format" );
392396
@@ -404,7 +408,7 @@ BruteForce::Deserialize(StreamReader& reader) {
404408 }
405409 }
406410 dim_ = basic_info[" dim" ].GetInt ();
407- total_count_ = basic_info[" total_count" ].GetInt ();
411+ total_count_. store ( basic_info[" total_count" ].GetInt () );
408412
409413 if (this ->use_attribute_filter_ and this ->attr_filter_index_ != nullptr ) {
410414 this ->attr_filter_index_ ->Deserialize (buffer_reader);
0 commit comments