Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/algorithm/pyramid/pyramid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,23 @@ Pyramid::search_impl(const DatasetPtr& query,

// Returns the element count reported for this index, derived from base_codes_->TotalCount().
// NOTE(review): this hunk shows two consecutive return statements. The diff header for
// pyramid.cpp reports "20 additions & 1 deletion", so the first return is presumably the
// removed (pre-change) line and the second its replacement, which subtracts delete_count_.
// Confirm against the raw diff; read literally as C++, the second return is unreachable.
int64_t
Pyramid::GetNumElements() const {
return base_codes_->TotalCount();
return base_codes_->TotalCount() - delete_count_;
}

int64_t
Pyramid::GetNumberRemoved() const {
return delete_count_;
}

uint32_t
Pyramid::Remove(const std::vector<int64_t>& ids, RemoveMode mode) {
if (mode != RemoveMode::MARK_REMOVE) {
throw VsagException(ErrorType::INVALID_ARGUMENT, "Pyramid only supports MARK_REMOVE");
}
std::scoped_lock lock(this->label_lookup_mutex_, this->cur_element_count_mutex_);
uint32_t delete_count = this->label_table_->MarkRemove(ids);
delete_count_ += delete_count;
return delete_count;
Comment thread
LHT129 marked this conversation as resolved.
}

void
Expand Down Expand Up @@ -442,6 +458,7 @@ Pyramid::Deserialize(StreamReader& reader) {
&reader, std::numeric_limits<uint64_t>::max(), this->allocator_);

label_table_->Deserialize(buffer_reader);
delete_count_ = static_cast<int64_t>(label_table_->GetAllDeletedIds().size());
base_codes_->Deserialize(buffer_reader);
if (use_reorder_) {
precise_codes_->Deserialize(buffer_reader);
Expand Down Expand Up @@ -613,6 +630,8 @@ Pyramid::InitFeatures() {
IndexFeature::SUPPORT_EXPORT_MODEL,
IndexFeature::SUPPORT_GET_MEMORY_USAGE,
});

this->index_feature_list_->SetFeatures({IndexFeature::SUPPORT_DELETE_BY_ID});
}

static const std::string HGRAPH_PARAMS_TEMPLATE =
Expand Down
7 changes: 7 additions & 0 deletions src/algorithm/pyramid/pyramid.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ class Pyramid : public InnerIndexInterface {
// Number of elements reported by the index.
// NOTE(review): the pyramid.cpp hunk in this diff appears to change the definition so that
// delete_count_ is subtracted from base_codes_->TotalCount() — confirm against the raw diff.
int64_t
GetNumElements() const override;

// Returns the current value of delete_count_ (see the definition in pyramid.cpp).
int64_t
GetNumberRemoved() const override;

// Marks `ids` as removed via the label table and returns the count it reports.
// The definition in pyramid.cpp throws VsagException (INVALID_ARGUMENT) when `mode`
// is anything other than RemoveMode::MARK_REMOVE.
uint32_t
Remove(const std::vector<int64_t>& ids, RemoveMode mode) override;

Comment thread
LHT129 marked this conversation as resolved.
std::string
GetStats() const override;

Expand Down Expand Up @@ -249,6 +255,7 @@ class Pyramid : public InnerIndexInterface {
std::unique_ptr<BasicSearcher> searcher_ = nullptr;
int64_t max_capacity_{0};
int64_t cur_element_count_{0};
std::atomic<int64_t> delete_count_{0};
Comment thread
LHT129 marked this conversation as resolved.
float alpha_{1.0F};
bool support_duplicate_{false};

Expand Down
12 changes: 12 additions & 0 deletions src/algorithm/sindi/sindi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ SINDI::Deserialize(StreamReader& reader) {
}

label_table_->Deserialize(reader_ref);
delete_count_ = static_cast<int64_t>(label_table_->GetAllDeletedIds().size());

if (use_reorder_) {
rerank_flat_index_->Deserialize(reader_ref);
Expand Down Expand Up @@ -750,6 +751,7 @@ SINDI::InitFeatures() {
IndexFeature::SUPPORT_UPDATE_VECTOR_CONCURRENT});

// metric
this->index_feature_list_->SetFeatures({IndexFeature::SUPPORT_DELETE_BY_ID});
this->index_feature_list_->SetFeature(IndexFeature::SUPPORT_METRIC_TYPE_INNER_PRODUCT);
}

Expand Down Expand Up @@ -819,5 +821,15 @@ SINDI::remap_sparse_vector_for_build(const SparseVector& input, Vector<uint32_t>
remapped.vals_ = input.vals_;
return remapped;
}
uint32_t
SINDI::Remove(const std::vector<int64_t>& ids, RemoveMode mode) {
if (mode != RemoveMode::MARK_REMOVE) {
throw VsagException(ErrorType::INVALID_ARGUMENT, "SINDI only supports MARK_REMOVE");
}
std::scoped_lock lock(this->global_mutex_);
uint32_t delete_count = this->label_table_->MarkRemove(ids);
delete_count_ += delete_count;
return delete_count;
Comment thread
LHT129 marked this conversation as resolved.
Comment on lines +825 to +832
}

} // namespace vsag
12 changes: 11 additions & 1 deletion src/algorithm/sindi/sindi.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,17 @@ class SINDI : public InnerIndexInterface {

// Returns the element count reported for this index, derived from cur_element_count_.
// NOTE(review): this hunk shows two consecutive return statements. The diff header for
// sindi.h reports "11 additions & 1 deletion", so the first return is presumably the
// removed (pre-change) line and the second its replacement, which subtracts delete_count_.
// Confirm against the raw diff; read literally as C++, the second return is unreachable.
int64_t
GetNumElements() const override {
return cur_element_count_;
return cur_element_count_ - delete_count_;
}
Comment on lines 110 to 113

int64_t
GetNumberRemoved() const override {
return delete_count_;
}

// Marks `ids` as removed via the label table and returns the count it reports.
// The definition in sindi.cpp throws VsagException (INVALID_ARGUMENT) when `mode`
// is anything other than RemoveMode::MARK_REMOVE.
uint32_t
Remove(const std::vector<int64_t>& ids, RemoveMode mode) override;

Comment thread
LHT129 marked this conversation as resolved.
[[nodiscard]] uint64_t
EstimateMemory(uint64_t num_elements) const override;

Expand Down Expand Up @@ -166,6 +174,8 @@ class SINDI : public InnerIndexInterface {

int64_t cur_element_count_{0};

std::atomic<int64_t> delete_count_{0};
Comment thread
LHT129 marked this conversation as resolved.

bool use_reorder_{false};

bool use_quantization_{false};
Expand Down
12 changes: 12 additions & 0 deletions src/algorithm/warp/warp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ WARP::Deserialize(StreamReader& reader) {

this->inner_codes_->Deserialize(buffer_reader);
this->label_table_->Deserialize(buffer_reader);
delete_count_ = label_table_->GetAllDeletedIds().size();
this->cal_memory_usage();
}

Expand Down Expand Up @@ -644,6 +645,17 @@ WARP::InitFeatures() {
});
}

uint32_t
WARP::Remove(const std::vector<int64_t>& ids, RemoveMode mode) {
if (mode != RemoveMode::MARK_REMOVE) {
throw VsagException(ErrorType::INVALID_ARGUMENT, "WARP only supports MARK_REMOVE");
}
std::scoped_lock label_lock(this->label_lookup_mutex_);
uint32_t delete_count = this->label_table_->MarkRemove(ids);
delete_count_ += delete_count;
Comment thread
LHT129 marked this conversation as resolved.
return delete_count;
}
Comment thread
LHT129 marked this conversation as resolved.
Comment thread
LHT129 marked this conversation as resolved.

static const std::string WARP_PARAMS_TEMPLATE =
R"(
{
Expand Down
5 changes: 4 additions & 1 deletion src/algorithm/warp/warp.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class WARP : public InnerIndexInterface {
return this->delete_count_;
}

// Marks `ids` as removed via the label table and returns the count it reports.
// The definition in warp.cpp throws VsagException (INVALID_ARGUMENT) when `mode`
// is anything other than RemoveMode::MARK_REMOVE.
uint32_t
Remove(const std::vector<int64_t>& ids, RemoveMode mode) override;

void
GetVectorByInnerId(InnerIdType inner_id, float* data) const override;

Expand Down Expand Up @@ -134,7 +137,7 @@ class WARP : public InnerIndexInterface {

uint64_t total_count_{0}; // Number of documents (not vectors)

uint64_t delete_count_{0};
std::atomic<uint64_t> delete_count_{0};

uint64_t total_vector_count_{0}; // Total number of vectors across all docs

Expand Down
50 changes: 50 additions & 0 deletions tests/test_pyramid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -722,3 +722,53 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::PyramidTestIndex,
REQUIRE(stats.contains("subindex_quality"));
}
}

// Functional test for mark-removal on the Pyramid index.
// For every dimension in `dims` it builds an index, then checks in order:
//   1. element / removed counters right after build,
//   2. that FORCE_REMOVE is rejected,
//   3. that mark-removing the first half of the base ids updates both counters,
//   4. that removing the same ids a second time reports 0,
//   5. that none of the removed ids shows up when its own vector is used as a query.
TEST_CASE_PERSISTENT_FIXTURE(fixtures::PyramidTestIndex,
"Pyramid Mark Remove",
"[ft][remove][pyramid]") {
auto metric_type = GENERATE("l2");
PyramidParam pyramid_param;
pyramid_param.no_build_levels = {0, 1, 2};
const std::string name = "pyramid";
auto search_param = GeneratePyramidSearchParametersString(200);
for (auto& dim : dims) {
INFO(fmt::format("metric_type={}, dim={}", metric_type, dim));
auto param = GeneratePyramidBuildParametersString(metric_type, dim, pyramid_param);
auto index = TestFactory(name, param, true);
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type, /*with_path=*/true);
TestBuildIndex(index, dataset, true);

// baseline: everything that was built is counted, nothing is removed yet
auto base_num = dataset->base_->GetNumElements();
const auto* ids = dataset->base_->GetIds();
REQUIRE(index->GetNumElements() == base_num);
REQUIRE(index->GetNumberRemoved() == 0);

// FORCE_REMOVE is not supported by Pyramid
auto force_result = index->Remove(ids[0], vsag::RemoveMode::FORCE_REMOVE);
REQUIRE_FALSE(force_result.has_value());

// mark remove half of the base data
int64_t remove_count = base_num / 2;
std::vector<int64_t> remove_ids(ids, ids + remove_count);
auto remove_result = index->Remove(remove_ids, vsag::RemoveMode::MARK_REMOVE);
REQUIRE(remove_result.has_value());
REQUIRE(remove_result.value() == remove_count);
REQUIRE(index->GetNumElements() == base_num - remove_count);
REQUIRE(index->GetNumberRemoved() == remove_count);

// removing the same ids again should remove nothing
auto duplicate_remove = index->Remove(remove_ids, vsag::RemoveMode::MARK_REMOVE);
REQUIRE(duplicate_remove.has_value());
REQUIRE(duplicate_remove.value() == 0);

// removed ids must not appear in search results
// Each removed base vector is used as the query (k = 10) and its own id must be absent.
for (int64_t i = 0; i < remove_count; ++i) {
auto query = fixtures::get_one_query(dataset->base_, static_cast<int>(i));
auto search_result = index->KnnSearch(query, 10, search_param);
REQUIRE(search_result.has_value());
// NOTE(review): GetDim() is used as the number of returned ids — presumably the result
// dataset stores the hit count in its dim field; confirm against the search API.
for (int64_t j = 0; j < search_result.value()->GetDim(); ++j) {
REQUIRE(search_result.value()->GetIds()[j] != ids[i]);
}
}
}
}
43 changes: 43 additions & 0 deletions tests/test_sindi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,46 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::SINDITestIndex,
TestRangeSearch(index, dataset, search_param, 0.99, 10, true);
TestFilterSearch(index, dataset, search_param, 0.99, true);
}

// Functional test for mark-removal on the SINDI index, run once with use_reorder = true and
// once with use_reorder = false. After building on a sparse dataset it checks in order:
//   1. element / removed counters right after build,
//   2. that FORCE_REMOVE is rejected,
//   3. that mark-removing the first half of the base ids updates both counters,
//   4. that removing the same ids a second time reports 0,
//   5. that none of the removed ids shows up when its own vector is used as a query.
TEST_CASE_PERSISTENT_FIXTURE(fixtures::SINDITestIndex, "SINDI Mark Remove", "[ft][remove][sindi]") {
fixtures::SINDIParam param;
param.use_reorder = GENERATE(true, false);
auto build_param = fixtures::SINDITestIndex::GenerateBuildParameter(param);
auto index = TestFactory("sindi", build_param, true);
auto dataset = pool.GetSparseDatasetAndCreate(base_count, 128, 0.8);
REQUIRE(index->GetIndexType() == vsag::IndexType::SINDI);
TestBuildIndex(index, dataset, true);

// baseline: everything that was built is counted, nothing is removed yet
auto base_num = dataset->base_->GetNumElements();
const auto* ids = dataset->base_->GetIds();
REQUIRE(index->GetNumElements() == base_num);
REQUIRE(index->GetNumberRemoved() == 0);

// FORCE_REMOVE is not supported by SINDI
auto force_result = index->Remove(ids[0], vsag::RemoveMode::FORCE_REMOVE);
REQUIRE_FALSE(force_result.has_value());

// mark remove half of the base data
int64_t remove_count = base_num / 2;
std::vector<int64_t> remove_ids(ids, ids + remove_count);
auto remove_result = index->Remove(remove_ids, vsag::RemoveMode::MARK_REMOVE);
REQUIRE(remove_result.has_value());
REQUIRE(remove_result.value() == remove_count);
REQUIRE(index->GetNumElements() == base_num - remove_count);
REQUIRE(index->GetNumberRemoved() == remove_count);

// removing the same ids again should remove nothing
auto duplicate_remove = index->Remove(remove_ids, vsag::RemoveMode::MARK_REMOVE);
REQUIRE(duplicate_remove.has_value());
REQUIRE(duplicate_remove.value() == 0);

// removed ids must not appear in search results
// NOTE(review): search_param is not declared inside this test body — presumably it is
// provided by the fixture or at file scope; confirm it resolves in test_sindi.cpp.
for (int64_t i = 0; i < remove_count; ++i) {
auto query = fixtures::get_one_query(dataset->base_, static_cast<int>(i));
auto search_result = index->KnnSearch(query, 10, search_param);
REQUIRE(search_result.has_value());
// NOTE(review): GetDim() is used as the number of returned ids — presumably the result
// dataset stores the hit count in its dim field; confirm against the search API.
for (int64_t j = 0; j < search_result.value()->GetDim(); ++j) {
REQUIRE(search_result.value()->GetIds()[j] != ids[i]);
}
}
}
50 changes: 50 additions & 0 deletions tests/test_warp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,53 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::WarpTestIndex,
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}

// Functional test for mark-removal on the WARP index ("ip" metric, "fp32" base quantization).
// For every dimension in `dims` it builds an index, then checks in order:
//   1. element / removed counters right after build,
//   2. that FORCE_REMOVE is rejected,
//   3. that mark-removing the first half of the base ids updates both counters,
//   4. that removing the same ids a second time reports 0,
//   5. that none of the removed ids shows up when its own entry is used as a query.
TEST_CASE_PERSISTENT_FIXTURE(fixtures::WarpTestIndex, "Warp Mark Remove", "[ft][remove][warp]") {
auto metric_type = GENERATE("ip");
std::string base_quantization_str = GENERATE("fp32");
WarpParam warp_param;
warp_param.base_quantization_type = base_quantization_str;
const std::string name = "warp";
auto search_param = GenerateWarpSearchParametersString();
for (auto& dim : dims) {
INFO(fmt::format("metric_type={}, dim={}", metric_type, dim));
auto param = GenerateWarpBuildParametersString(metric_type, dim, warp_param);
auto index = TestFactory(name, param, true);
auto dataset =
pool.GetDatasetAndCreate(dim, base_count, metric_type, false, 0.8, 0, 16, "multi");
TestBuildIndex(index, dataset, true);

// baseline: everything that was built is counted, nothing is removed yet
auto base_num = dataset->base_->GetNumElements();
const auto* ids = dataset->base_->GetIds();
REQUIRE(index->GetNumElements() == base_num);
REQUIRE(index->GetNumberRemoved() == 0);

// FORCE_REMOVE is not supported by WARP
auto force_result = index->Remove(ids[0], vsag::RemoveMode::FORCE_REMOVE);
REQUIRE_FALSE(force_result.has_value());

// mark remove half of the base data
int64_t remove_count = base_num / 2;
std::vector<int64_t> remove_ids(ids, ids + remove_count);
auto remove_result = index->Remove(remove_ids, vsag::RemoveMode::MARK_REMOVE);
REQUIRE(remove_result.has_value());
REQUIRE(remove_result.value() == remove_count);
REQUIRE(index->GetNumElements() == base_num - remove_count);
REQUIRE(index->GetNumberRemoved() == remove_count);

// removing the same ids again should remove nothing
auto duplicate_remove = index->Remove(remove_ids, vsag::RemoveMode::MARK_REMOVE);
REQUIRE(duplicate_remove.has_value());
REQUIRE(duplicate_remove.value() == 0);

// removed ids must not appear in search results
// Each removed base entry is used as the query (k = 10) and its own id must be absent.
for (int64_t i = 0; i < remove_count; ++i) {
auto query = fixtures::get_one_query(dataset->base_, static_cast<int>(i));
auto search_result = index->KnnSearch(query, 10, search_param);
REQUIRE(search_result.has_value());
// NOTE(review): GetDim() is used as the number of returned ids — presumably the result
// dataset stores the hit count in its dim field; confirm against the search API.
for (int64_t j = 0; j < search_result.value()->GetDim(); ++j) {
REQUIRE(search_result.value()->GetIds()[j] != ids[i]);
}
}
}
}
Loading