Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions src/datacell/flatten_datacell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,216 @@ TEST_CASE("RaBitQSplitDataCell direct split compute", "[ut][RaBitQSplitDataCell]
}
}
}
// Exercises the FlattenInterface API surface of a "rabitq_split" data cell that uses
// memory IO and the L2SQR metric. The setup before the first SECTION creates and trains
// an empty cell; Catch2 re-runs that setup for every SECTION, so each SECTION inserts
// its own data.
TEST_CASE("RaBitQSplitDataCell serialize and methods", "[ut][RaBitQSplitDataCell]") {
    auto allocator = SafeAllocator::FactoryDefaultAllocator();
    constexpr uint64_t dim = 64;
    constexpr InnerIdType count = 32;
    auto vectors = fixtures::generate_vectors(count, dim);

    constexpr const char* param_str = R"(
{
"codes_type": "rabitq_split",
"io_params": {
"type": "memory_io"
},
"quantization_params": {
"type": "rabitq",
"rabitq_version": "split_1bit_7bit",
"rabitq_bits_per_dim_query": 32,
"rabitq_bits_per_dim_base": 4
}
}
)";

    auto param_json = JsonType::Parse(param_str);
    auto param = std::make_shared<FlattenDataCellParameter>();
    param->FromJson(param_json);

    IndexCommonParam common_param;
    common_param.allocator_ = allocator;
    common_param.dim_ = dim;
    common_param.metric_ = MetricType::METRIC_TYPE_L2SQR;

    auto flatten = FlattenInterface::MakeInstance(param, common_param);
    flatten->Train(vectors.data(), count);

    // One-by-one insertion, then an update at a valid id and at an id beyond the data.
    SECTION("InsertVector and UpdateVector") {
        for (InnerIdType i = 0; i < count; ++i) {
            flatten->InsertVector(vectors.data() + i * dim);
        }
        REQUIRE(flatten->TotalCount() == count);

        REQUIRE(flatten->UpdateVector(vectors.data(), 0) == true);
        // count + 10 was never inserted, so the update is expected to be rejected.
        REQUIRE(flatten->UpdateVector(vectors.data(), count + 10) == false);
    }

    // Batch insertion where the caller supplies the ids 0..count-1 explicitly.
    SECTION("BatchInsertVector with explicit ids") {
        std::vector<InnerIdType> ids(count);
        std::iota(ids.begin(), ids.end(), 0);
        flatten->BatchInsertVector(vectors.data(), count, ids.data());
        REQUIRE(flatten->TotalCount() == count);
    }

    // Round-trips the cell through a string stream and checks that the restored cell
    // reports the same count and yields identical distances for the same query.
    SECTION("Serialize and Deserialize") {
        flatten->BatchInsertVector(vectors.data(), count);

        std::stringstream ss;
        IOStreamWriter writer(ss);
        flatten->Serialize(writer);
        ss.seekg(0, std::ios::beg);
        IOStreamReader reader(ss);

        auto other = FlattenInterface::MakeInstance(param, common_param);
        other->Train(vectors.data(), count);
        other->Deserialize(reader);
        REQUIRE(other->TotalCount() == flatten->TotalCount());

        auto query = fixtures::generate_vectors(1, dim, 99);
        // The computer is built from `flatten` and reused against `other` on purpose:
        // both cells are expected to hold the same quantizer state after Deserialize.
        auto computer = flatten->FactoryComputer(query.data());
        std::vector<InnerIdType> idx(count);
        std::iota(idx.begin(), idx.end(), 0);
        std::vector<float> dists1(count), dists2(count);
        flatten->Query(dists1.data(), computer, idx.data(), count);
        other->Query(dists2.data(), computer, idx.data(), count);
        for (InnerIdType i = 0; i < count; ++i) {
            REQUIRE(dists1[i] == dists2[i]);
        }
    }

    // Pointer-returning GetCodesById: the result must be non-null, and is handed back
    // through Release only when the call reports that a release is needed.
    SECTION("GetCodesById") {
        flatten->BatchInsertVector(vectors.data(), count);
        bool need_release = false;
        const auto* code0 = flatten->GetCodesById(0, need_release);
        REQUIRE(code0 != nullptr);
        if (need_release) {
            flatten->Release(code0);
        }
    }

    // Encode must report success; Decode is only smoke-tested (its output is not compared).
    SECTION("Encode and Decode") {
        flatten->BatchInsertVector(vectors.data(), count);
        auto code_size = flatten->code_size_;
        std::vector<uint8_t> codes(code_size);
        REQUIRE(flatten->Encode(vectors.data(), codes.data()) == true);
        std::vector<float> decoded(dim);
        flatten->Decode(codes.data(), decoded.data());
    }

    // Smoke test: growing and then shrinking the capacity must not fail.
    SECTION("Resize and ShrinkToFit") {
        flatten->BatchInsertVector(vectors.data(), count);
        flatten->Resize(count * 2);
        flatten->ShrinkToFit(count);
    }

    // Smoke test: Move(0, count) must not fail; no result is asserted.
    SECTION("Move") {
        flatten->BatchInsertVector(vectors.data(), count);
        flatten->Move(0, count);
    }

    // Covers both GetCodesById overloads: the pointer-returning one and the one that
    // copies the code into a caller-provided buffer of code_size_ bytes.
    SECTION("GetCodesById variants") {
        flatten->BatchInsertVector(vectors.data(), count);
        bool need_release = false;
        const auto* codes = flatten->GetCodesById(0, need_release);
        REQUIRE(codes != nullptr);
        if (need_release) {
            flatten->Release(codes);
        }

        auto code_size = flatten->code_size_;
        std::vector<uint8_t> buf(code_size);
        REQUIRE(flatten->GetCodesById(0, buf.data()) == true);
    }

    // Smoke test: exporting the model into a second trained cell must not fail.
    SECTION("ExportModel") {
        flatten->BatchInsertVector(vectors.data(), count);
        auto other = FlattenInterface::MakeInstance(param, common_param);
        other->Train(vectors.data(), count);
        flatten->ExportModel(other);
    }

    // The first half of the vectors goes into `flatten`, the second half into `other`;
    // after merging, `flatten` must report the full count.
    SECTION("MergeOther") {
        flatten->BatchInsertVector(vectors.data(), count / 2);
        auto other_param = std::make_shared<FlattenDataCellParameter>();
        other_param->FromJson(param_json);
        auto other = FlattenInterface::MakeInstance(other_param, common_param);
        other->Train(vectors.data(), count);
        other->BatchInsertVector(vectors.data() + (count / 2) * dim, count / 2);
        flatten->MergeOther(other, count / 2);
        REQUIRE(flatten->TotalCount() == count);
    }

    // Accessors checked on the trained but still empty cell.
    SECTION("Metadata methods") {
        REQUIRE_FALSE(flatten->GetQuantizerName().empty());
        REQUIRE(flatten->GetMetricType() == MetricType::METRIC_TYPE_L2SQR);
        REQUIRE(flatten->InMemory() == true);
        auto memory = flatten->GetMemoryUsage();
        REQUIRE(memory > 0);
    }

    // The filter threshold is float max, and every returned distance must be finite.
    SECTION("QueryWithDistanceFilter") {
        flatten->BatchInsertVector(vectors.data(), count);
        auto query = fixtures::generate_vectors(1, dim, 42);
        auto computer = flatten->FactoryComputer(query.data());
        std::vector<InnerIdType> idx(count);
        std::iota(idx.begin(), idx.end(), 0);
        std::vector<float> dists(count);
        flatten->QueryWithDistanceFilter(
            dists.data(), computer, idx.data(), count, std::numeric_limits<float>::max());
        for (InnerIdType i = 0; i < count; ++i) {
            REQUIRE(std::isfinite(dists[i]));
        }
    }
}

// Builds a "rabitq_split" data cell with the inner-product metric, fills it with
// `count` vectors and checks that both Query and QueryWithDistanceLowerBound write
// only finite distances.
TEST_CASE("RaBitQSplitDataCell IP metric", "[ut][RaBitQSplitDataCell]") {
    auto allocator = SafeAllocator::FactoryDefaultAllocator();
    constexpr uint64_t dim = 64;
    constexpr InnerIdType count = 16;
    auto vectors = fixtures::generate_vectors(count, dim);
    // NOTE(review): two query vectors are generated but only the first one is used
    // below (queries.data()) — confirm whether the second was meant to be queried too.
    auto queries = fixtures::generate_vectors(2, dim, 42);

    constexpr const char* param_str = R"(
{
"codes_type": "rabitq_split",
"io_params": {
"type": "memory_io"
},
"quantization_params": {
"type": "rabitq",
"rabitq_version": "split_1bit_7bit",
"rabitq_bits_per_dim_query": 32,
"rabitq_bits_per_dim_base": 4
}
}
)";

    auto param_json = JsonType::Parse(param_str);
    auto param = std::make_shared<FlattenDataCellParameter>();
    param->FromJson(param_json);

    IndexCommonParam common;
    common.allocator_ = allocator;
    common.dim_ = dim;
    common.metric_ = MetricType::METRIC_TYPE_IP;

    auto cell = FlattenInterface::MakeInstance(param, common);
    cell->Train(vectors.data(), count);
    cell->BatchInsertVector(vectors.data(), count);

    // Query every stored id, in order.
    std::vector<InnerIdType> ids(count);
    std::iota(ids.begin(), ids.end(), 0);
    std::vector<float> distances(count);
    std::vector<float> bounds(count);

    auto query_computer = cell->FactoryComputer(queries.data());

    // Plain distance query.
    cell->Query(distances.data(), query_computer, ids.data(), count);
    for (float distance : distances) {
        REQUIRE(std::isfinite(distance));
    }

    // Distance query that also writes lower bounds; the distances are overwritten.
    // NOTE(review): `bounds` is filled here but never asserted — confirm whether a
    // check on the lower bounds is intended.
    cell->QueryWithDistanceLowerBound(
        distances.data(), bounds.data(), query_computer, ids.data(), count);
    for (float distance : distances) {
        REQUIRE(std::isfinite(distance));
    }
}
72 changes: 72 additions & 0 deletions src/io/mmap_io_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,75 @@ TEST_CASE("MMapIO Serialize & Deserialize", "[ut][MMapIO]") {
auto rio = std::make_unique<MMapIO>(path2, allocator.get());
TestSerializeAndDeserialize(*wio, *rio);
}

// Constructing an MMapIO on a path that names a directory is expected to throw.
TEST_CASE("MMapIO directory path error", "[ut][MMapIO]") {
    auto allocator = SafeAllocator::FactoryDefaultAllocator();
    fixtures::TempDir dir("mmap_io_dir_test");

    // Use the temporary directory's own path as the (invalid) file path.
    auto directory_as_file = dir.path;
    REQUIRE_THROWS(std::make_unique<MMapIO>(directory_as_file, allocator.get()));
}

// Grows and then shrinks an MMapIO and checks that the reported size follows each
// Resize call and that the bytes inside the shrunken range are still readable and
// unchanged.
TEST_CASE("MMapIO resize shrink", "[ut][MMapIO]") {
    auto allocator = SafeAllocator::FactoryDefaultAllocator();
    fixtures::TempDir dir("mmap_io_resize");
    auto path = dir.GenerateRandomFile(false);
    auto io = std::make_unique<MMapIO>(path, allocator.get());

    constexpr uint64_t written_size = 4096;
    constexpr uint64_t grown_size = 8192;
    constexpr uint64_t shrunk_size = 2048;
    constexpr uint8_t pattern = 0xAB;

    std::vector<uint8_t> data(written_size, pattern);
    io->Write(data.data(), data.size(), 0);

    // Growing may leave the size at or above the requested value.
    io->Resize(grown_size);
    REQUIRE(io->size_ >= grown_size);

    // Shrinking must land exactly on the requested size.
    io->Resize(shrunk_size);
    REQUIRE(io->size_ == shrunk_size);

    // The surviving prefix was written before either resize and must be intact.
    std::vector<uint8_t> read_buf(shrunk_size);
    REQUIRE(io->Read(shrunk_size, 0, read_buf.data()) == true);
    for (uint64_t i = 0; i < shrunk_size; ++i) {
        REQUIRE(read_buf[i] == pattern);
    }
}

// Writes the byte sequence 0..255 and reads it back with one MultiRead call made of
// three adjacent 64-byte requests; the destination buffer must then hold the bytes
// 0..191 in order.
TEST_CASE("MMapIO MultiRead", "[ut][MMapIO]") {
    auto allocator = SafeAllocator::FactoryDefaultAllocator();
    fixtures::TempDir dir("mmap_io_multi");
    auto path = dir.GenerateRandomFile(false);
    auto io = std::make_unique<MMapIO>(path, allocator.get());

    // Source pattern: each byte holds its own offset.
    std::vector<uint8_t> pattern(256);
    uint8_t next_value = 0;
    for (auto& byte : pattern) {
        byte = next_value++;
    }
    io->Write(pattern.data(), pattern.size(), 0);

    // Three adjacent requests of `chunk` bytes each, starting at offset 0.
    constexpr uint64_t chunk = 64;
    std::vector<uint64_t> sizes = {chunk, chunk, chunk};
    std::vector<uint64_t> offsets = {0, chunk, 2 * chunk};
    std::vector<uint8_t> gathered(3 * chunk);
    io->MultiRead(gathered.data(), sizes.data(), offsets.data(), 3);

    for (uint64_t pos = 0; pos < gathered.size(); ++pos) {
        REQUIRE(gathered[pos] == static_cast<uint8_t>(pos));
    }
}

// Opens an MMapIO on a path, writes through it, lets it go out of scope, and then
// opens a second MMapIO on the same path. The second instance must accept a write
// and return exactly the bytes just written.
TEST_CASE("MMapIO existing file", "[ut][MMapIO]") {
    auto allocator = SafeAllocator::FactoryDefaultAllocator();
    fixtures::TempDir dir("mmap_io_exist");
    auto path = dir.GenerateRandomFile(true);

    {
        // First instance: write 128 bytes and destroy the IO at the end of the scope.
        auto io = std::make_unique<MMapIO>(path, allocator.get());
        std::vector<uint8_t> data(128, 0xCD);
        io->Write(data.data(), data.size(), 0);
    }

    // Second instance on the same path: overwrite the first 64 bytes and read them back.
    auto io2 = std::make_unique<MMapIO>(path, allocator.get());
    std::vector<uint8_t> data2(64, 0xEF);
    io2->Write(data2.data(), data2.size(), 0);
    std::vector<uint8_t> read_buf(64);
    // Assert the read itself so a failed read is reported as such rather than as a
    // byte mismatch against the zero-initialised buffer.
    REQUIRE(io2->Read(64, 0, read_buf.data()) == true);
    for (uint64_t i = 0; i < 64; ++i) {
        REQUIRE(read_buf[i] == 0xEF);
    }
}
Loading
Loading