Skip to content

Commit 20a85bb

Browse files
authored
test: add coverage tests for rabitq_split, mmap_io, warp, and c_api (#2148)
Cover previously untested paths toward 90% main branch coverage: - RaBitQSplitDataCell: InsertVector, UpdateVector, Serialize/Deserialize, Encode/Decode, Resize, ShrinkToFit, Move, ExportModel, MergeOther, QueryWithDistanceFilter, IP metric instantiation - MMapIO: directory path error, resize shrink, MultiRead, existing file - Warp: multiple dimensions with IP metric - vsag_c_api: null handles, invalid inputs, update failures, serialize edges Assisted-by: OpenCode:claude-opus-4.7 Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent 75fbaaa commit 20a85bb

4 files changed

Lines changed: 510 additions & 1 deletion

File tree

src/datacell/flatten_datacell_test.cpp

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,216 @@ TEST_CASE("RaBitQSplitDataCell direct split compute", "[ut][RaBitQSplitDataCell]
145145
}
146146
}
147147
}
148+
TEST_CASE("RaBitQSplitDataCell serialize and methods", "[ut][RaBitQSplitDataCell]") {
149+
auto allocator = SafeAllocator::FactoryDefaultAllocator();
150+
constexpr uint64_t dim = 64;
151+
constexpr InnerIdType count = 32;
152+
auto vectors = fixtures::generate_vectors(count, dim);
153+
154+
constexpr const char* param_str = R"(
155+
{
156+
"codes_type": "rabitq_split",
157+
"io_params": {
158+
"type": "memory_io"
159+
},
160+
"quantization_params": {
161+
"type": "rabitq",
162+
"rabitq_version": "split_1bit_7bit",
163+
"rabitq_bits_per_dim_query": 32,
164+
"rabitq_bits_per_dim_base": 4
165+
}
166+
}
167+
)";
168+
169+
auto param_json = JsonType::Parse(param_str);
170+
auto param = std::make_shared<FlattenDataCellParameter>();
171+
param->FromJson(param_json);
172+
173+
IndexCommonParam common_param;
174+
common_param.allocator_ = allocator;
175+
common_param.dim_ = dim;
176+
common_param.metric_ = MetricType::METRIC_TYPE_L2SQR;
177+
178+
auto flatten = FlattenInterface::MakeInstance(param, common_param);
179+
flatten->Train(vectors.data(), count);
180+
181+
SECTION("InsertVector and UpdateVector") {
182+
for (InnerIdType i = 0; i < count; ++i) {
183+
flatten->InsertVector(vectors.data() + i * dim);
184+
}
185+
REQUIRE(flatten->TotalCount() == count);
186+
187+
REQUIRE(flatten->UpdateVector(vectors.data(), 0) == true);
188+
REQUIRE(flatten->UpdateVector(vectors.data(), count + 10) == false);
189+
}
190+
191+
SECTION("BatchInsertVector with explicit ids") {
192+
std::vector<InnerIdType> ids(count);
193+
std::iota(ids.begin(), ids.end(), 0);
194+
flatten->BatchInsertVector(vectors.data(), count, ids.data());
195+
REQUIRE(flatten->TotalCount() == count);
196+
}
197+
198+
SECTION("Serialize and Deserialize") {
199+
flatten->BatchInsertVector(vectors.data(), count);
200+
201+
std::stringstream ss;
202+
IOStreamWriter writer(ss);
203+
flatten->Serialize(writer);
204+
ss.seekg(0, std::ios::beg);
205+
IOStreamReader reader(ss);
206+
207+
auto other = FlattenInterface::MakeInstance(param, common_param);
208+
other->Train(vectors.data(), count);
209+
other->Deserialize(reader);
210+
REQUIRE(other->TotalCount() == flatten->TotalCount());
211+
212+
auto query = fixtures::generate_vectors(1, dim, 99);
213+
auto computer = flatten->FactoryComputer(query.data());
214+
std::vector<InnerIdType> idx(count);
215+
std::iota(idx.begin(), idx.end(), 0);
216+
std::vector<float> dists1(count), dists2(count);
217+
flatten->Query(dists1.data(), computer, idx.data(), count);
218+
other->Query(dists2.data(), computer, idx.data(), count);
219+
for (InnerIdType i = 0; i < count; ++i) {
220+
REQUIRE(dists1[i] == dists2[i]);
221+
}
222+
}
223+
224+
SECTION("GetCodesById") {
225+
flatten->BatchInsertVector(vectors.data(), count);
226+
bool need_release = false;
227+
const auto* code0 = flatten->GetCodesById(0, need_release);
228+
REQUIRE(code0 != nullptr);
229+
if (need_release) {
230+
flatten->Release(code0);
231+
}
232+
}
233+
234+
SECTION("Encode and Decode") {
235+
flatten->BatchInsertVector(vectors.data(), count);
236+
auto code_size = flatten->code_size_;
237+
std::vector<uint8_t> codes(code_size);
238+
REQUIRE(flatten->Encode(vectors.data(), codes.data()) == true);
239+
std::vector<float> decoded(dim);
240+
flatten->Decode(codes.data(), decoded.data());
241+
}
242+
243+
SECTION("Resize and ShrinkToFit") {
244+
flatten->BatchInsertVector(vectors.data(), count);
245+
flatten->Resize(count * 2);
246+
flatten->ShrinkToFit(count);
247+
}
248+
249+
SECTION("Move") {
250+
flatten->BatchInsertVector(vectors.data(), count);
251+
flatten->Move(0, count);
252+
}
253+
254+
SECTION("GetCodesById variants") {
255+
flatten->BatchInsertVector(vectors.data(), count);
256+
bool need_release = false;
257+
const auto* codes = flatten->GetCodesById(0, need_release);
258+
REQUIRE(codes != nullptr);
259+
if (need_release) {
260+
flatten->Release(codes);
261+
}
262+
263+
auto code_size = flatten->code_size_;
264+
std::vector<uint8_t> buf(code_size);
265+
REQUIRE(flatten->GetCodesById(0, buf.data()) == true);
266+
}
267+
268+
SECTION("ExportModel") {
269+
flatten->BatchInsertVector(vectors.data(), count);
270+
auto other = FlattenInterface::MakeInstance(param, common_param);
271+
other->Train(vectors.data(), count);
272+
flatten->ExportModel(other);
273+
}
274+
275+
SECTION("MergeOther") {
276+
flatten->BatchInsertVector(vectors.data(), count / 2);
277+
auto other_param = std::make_shared<FlattenDataCellParameter>();
278+
other_param->FromJson(param_json);
279+
auto other = FlattenInterface::MakeInstance(other_param, common_param);
280+
other->Train(vectors.data(), count);
281+
other->BatchInsertVector(vectors.data() + (count / 2) * dim, count / 2);
282+
flatten->MergeOther(other, count / 2);
283+
REQUIRE(flatten->TotalCount() == count);
284+
}
285+
286+
SECTION("Metadata methods") {
287+
REQUIRE_FALSE(flatten->GetQuantizerName().empty());
288+
REQUIRE(flatten->GetMetricType() == MetricType::METRIC_TYPE_L2SQR);
289+
REQUIRE(flatten->InMemory() == true);
290+
auto memory = flatten->GetMemoryUsage();
291+
REQUIRE(memory > 0);
292+
}
293+
294+
SECTION("QueryWithDistanceFilter") {
295+
flatten->BatchInsertVector(vectors.data(), count);
296+
auto query = fixtures::generate_vectors(1, dim, 42);
297+
auto computer = flatten->FactoryComputer(query.data());
298+
std::vector<InnerIdType> idx(count);
299+
std::iota(idx.begin(), idx.end(), 0);
300+
std::vector<float> dists(count);
301+
flatten->QueryWithDistanceFilter(
302+
dists.data(), computer, idx.data(), count, std::numeric_limits<float>::max());
303+
for (InnerIdType i = 0; i < count; ++i) {
304+
REQUIRE(std::isfinite(dists[i]));
305+
}
306+
}
307+
}
308+
309+
TEST_CASE("RaBitQSplitDataCell IP metric", "[ut][RaBitQSplitDataCell]") {
310+
auto allocator = SafeAllocator::FactoryDefaultAllocator();
311+
constexpr uint64_t dim = 64;
312+
constexpr InnerIdType count = 16;
313+
auto vectors = fixtures::generate_vectors(count, dim);
314+
auto queries = fixtures::generate_vectors(2, dim, 42);
315+
316+
constexpr const char* param_str = R"(
317+
{
318+
"codes_type": "rabitq_split",
319+
"io_params": {
320+
"type": "memory_io"
321+
},
322+
"quantization_params": {
323+
"type": "rabitq",
324+
"rabitq_version": "split_1bit_7bit",
325+
"rabitq_bits_per_dim_query": 32,
326+
"rabitq_bits_per_dim_base": 4
327+
}
328+
}
329+
)";
330+
331+
auto param_json = JsonType::Parse(param_str);
332+
auto param = std::make_shared<FlattenDataCellParameter>();
333+
param->FromJson(param_json);
334+
335+
IndexCommonParam common_param;
336+
common_param.allocator_ = allocator;
337+
common_param.dim_ = dim;
338+
common_param.metric_ = MetricType::METRIC_TYPE_IP;
339+
340+
auto flatten = FlattenInterface::MakeInstance(param, common_param);
341+
flatten->Train(vectors.data(), count);
342+
flatten->BatchInsertVector(vectors.data(), count);
343+
344+
std::vector<InnerIdType> idx(count);
345+
std::iota(idx.begin(), idx.end(), 0);
346+
std::vector<float> dists(count);
347+
std::vector<float> lower_bounds(count);
348+
349+
auto computer = flatten->FactoryComputer(queries.data());
350+
flatten->Query(dists.data(), computer, idx.data(), count);
351+
for (InnerIdType i = 0; i < count; ++i) {
352+
REQUIRE(std::isfinite(dists[i]));
353+
}
354+
355+
flatten->QueryWithDistanceLowerBound(
356+
dists.data(), lower_bounds.data(), computer, idx.data(), count);
357+
for (InnerIdType i = 0; i < count; ++i) {
358+
REQUIRE(std::isfinite(dists[i]));
359+
}
360+
}

src/io/mmap_io_test.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,75 @@ TEST_CASE("MMapIO Serialize & Deserialize", "[ut][MMapIO]") {
6060
auto rio = std::make_unique<MMapIO>(path2, allocator.get());
6161
TestSerializeAndDeserialize(*wio, *rio);
6262
}
63+
64+
TEST_CASE("MMapIO directory path error", "[ut][MMapIO]") {
65+
auto allocator = SafeAllocator::FactoryDefaultAllocator();
66+
fixtures::TempDir dir("mmap_io_dir_test");
67+
auto dir_path = dir.path;
68+
REQUIRE_THROWS(std::make_unique<MMapIO>(dir_path, allocator.get()));
69+
}
70+
71+
TEST_CASE("MMapIO resize shrink", "[ut][MMapIO]") {
72+
auto allocator = SafeAllocator::FactoryDefaultAllocator();
73+
fixtures::TempDir dir("mmap_io_resize");
74+
auto path = dir.GenerateRandomFile(false);
75+
auto io = std::make_unique<MMapIO>(path, allocator.get());
76+
77+
std::vector<uint8_t> data(4096, 0xAB);
78+
io->Write(data.data(), data.size(), 0);
79+
80+
io->Resize(8192);
81+
REQUIRE(io->size_ >= 8192);
82+
83+
io->Resize(2048);
84+
REQUIRE(io->size_ == 2048);
85+
86+
std::vector<uint8_t> read_buf(2048);
87+
REQUIRE(io->Read(2048, 0, read_buf.data()) == true);
88+
for (uint64_t i = 0; i < 2048; ++i) {
89+
REQUIRE(read_buf[i] == 0xAB);
90+
}
91+
}
92+
93+
TEST_CASE("MMapIO MultiRead", "[ut][MMapIO]") {
94+
auto allocator = SafeAllocator::FactoryDefaultAllocator();
95+
fixtures::TempDir dir("mmap_io_multi");
96+
auto path = dir.GenerateRandomFile(false);
97+
auto io = std::make_unique<MMapIO>(path, allocator.get());
98+
99+
std::vector<uint8_t> data(256);
100+
for (uint64_t i = 0; i < 256; ++i) {
101+
data[i] = static_cast<uint8_t>(i);
102+
}
103+
io->Write(data.data(), data.size(), 0);
104+
105+
std::vector<uint64_t> sizes = {64, 64, 64};
106+
std::vector<uint64_t> offsets = {0, 64, 128};
107+
std::vector<uint8_t> result(192);
108+
io->MultiRead(result.data(), sizes.data(), offsets.data(), 3);
109+
110+
for (uint64_t i = 0; i < 192; ++i) {
111+
REQUIRE(result[i] == static_cast<uint8_t>(i));
112+
}
113+
}
114+
115+
TEST_CASE("MMapIO existing file", "[ut][MMapIO]") {
116+
auto allocator = SafeAllocator::FactoryDefaultAllocator();
117+
fixtures::TempDir dir("mmap_io_exist");
118+
auto path = dir.GenerateRandomFile(true);
119+
120+
{
121+
auto io = std::make_unique<MMapIO>(path, allocator.get());
122+
std::vector<uint8_t> data(128, 0xCD);
123+
io->Write(data.data(), data.size(), 0);
124+
}
125+
126+
auto io2 = std::make_unique<MMapIO>(path, allocator.get());
127+
std::vector<uint8_t> data2(64, 0xEF);
128+
io2->Write(data2.data(), data2.size(), 0);
129+
std::vector<uint8_t> read_buf(64);
130+
io2->Read(64, 0, read_buf.data());
131+
for (uint64_t i = 0; i < 64; ++i) {
132+
REQUIRE(read_buf[i] == 0xEF);
133+
}
134+
}

0 commit comments

Comments
 (0)