@@ -145,3 +145,216 @@ TEST_CASE("RaBitQSplitDataCell direct split compute", "[ut][RaBitQSplitDataCell]
145145 }
146146 }
147147}
148+ TEST_CASE (" RaBitQSplitDataCell serialize and methods" , " [ut][RaBitQSplitDataCell]" ) {
149+ auto allocator = SafeAllocator::FactoryDefaultAllocator ();
150+ constexpr uint64_t dim = 64 ;
151+ constexpr InnerIdType count = 32 ;
152+ auto vectors = fixtures::generate_vectors (count, dim);
153+
154+ constexpr const char * param_str = R"(
155+ {
156+ "codes_type": "rabitq_split",
157+ "io_params": {
158+ "type": "memory_io"
159+ },
160+ "quantization_params": {
161+ "type": "rabitq",
162+ "rabitq_version": "split_1bit_7bit",
163+ "rabitq_bits_per_dim_query": 32,
164+ "rabitq_bits_per_dim_base": 4
165+ }
166+ }
167+ )" ;
168+
169+ auto param_json = JsonType::Parse (param_str);
170+ auto param = std::make_shared<FlattenDataCellParameter>();
171+ param->FromJson (param_json);
172+
173+ IndexCommonParam common_param;
174+ common_param.allocator_ = allocator;
175+ common_param.dim_ = dim;
176+ common_param.metric_ = MetricType::METRIC_TYPE_L2SQR;
177+
178+ auto flatten = FlattenInterface::MakeInstance (param, common_param);
179+ flatten->Train (vectors.data (), count);
180+
181+ SECTION (" InsertVector and UpdateVector" ) {
182+ for (InnerIdType i = 0 ; i < count; ++i) {
183+ flatten->InsertVector (vectors.data () + i * dim);
184+ }
185+ REQUIRE (flatten->TotalCount () == count);
186+
187+ REQUIRE (flatten->UpdateVector (vectors.data (), 0 ) == true );
188+ REQUIRE (flatten->UpdateVector (vectors.data (), count + 10 ) == false );
189+ }
190+
191+ SECTION (" BatchInsertVector with explicit ids" ) {
192+ std::vector<InnerIdType> ids (count);
193+ std::iota (ids.begin (), ids.end (), 0 );
194+ flatten->BatchInsertVector (vectors.data (), count, ids.data ());
195+ REQUIRE (flatten->TotalCount () == count);
196+ }
197+
198+ SECTION (" Serialize and Deserialize" ) {
199+ flatten->BatchInsertVector (vectors.data (), count);
200+
201+ std::stringstream ss;
202+ IOStreamWriter writer (ss);
203+ flatten->Serialize (writer);
204+ ss.seekg (0 , std::ios::beg);
205+ IOStreamReader reader (ss);
206+
207+ auto other = FlattenInterface::MakeInstance (param, common_param);
208+ other->Train (vectors.data (), count);
209+ other->Deserialize (reader);
210+ REQUIRE (other->TotalCount () == flatten->TotalCount ());
211+
212+ auto query = fixtures::generate_vectors (1 , dim, 99 );
213+ auto computer = flatten->FactoryComputer (query.data ());
214+ std::vector<InnerIdType> idx (count);
215+ std::iota (idx.begin (), idx.end (), 0 );
216+ std::vector<float > dists1 (count), dists2 (count);
217+ flatten->Query (dists1.data (), computer, idx.data (), count);
218+ other->Query (dists2.data (), computer, idx.data (), count);
219+ for (InnerIdType i = 0 ; i < count; ++i) {
220+ REQUIRE (dists1[i] == dists2[i]);
221+ }
222+ }
223+
224+ SECTION (" GetCodesById" ) {
225+ flatten->BatchInsertVector (vectors.data (), count);
226+ bool need_release = false ;
227+ const auto * code0 = flatten->GetCodesById (0 , need_release);
228+ REQUIRE (code0 != nullptr );
229+ if (need_release) {
230+ flatten->Release (code0);
231+ }
232+ }
233+
234+ SECTION (" Encode and Decode" ) {
235+ flatten->BatchInsertVector (vectors.data (), count);
236+ auto code_size = flatten->code_size_ ;
237+ std::vector<uint8_t > codes (code_size);
238+ REQUIRE (flatten->Encode (vectors.data (), codes.data ()) == true );
239+ std::vector<float > decoded (dim);
240+ flatten->Decode (codes.data (), decoded.data ());
241+ }
242+
243+ SECTION (" Resize and ShrinkToFit" ) {
244+ flatten->BatchInsertVector (vectors.data (), count);
245+ flatten->Resize (count * 2 );
246+ flatten->ShrinkToFit (count);
247+ }
248+
249+ SECTION (" Move" ) {
250+ flatten->BatchInsertVector (vectors.data (), count);
251+ flatten->Move (0 , count);
252+ }
253+
254+ SECTION (" GetCodesById variants" ) {
255+ flatten->BatchInsertVector (vectors.data (), count);
256+ bool need_release = false ;
257+ const auto * codes = flatten->GetCodesById (0 , need_release);
258+ REQUIRE (codes != nullptr );
259+ if (need_release) {
260+ flatten->Release (codes);
261+ }
262+
263+ auto code_size = flatten->code_size_ ;
264+ std::vector<uint8_t > buf (code_size);
265+ REQUIRE (flatten->GetCodesById (0 , buf.data ()) == true );
266+ }
267+
268+ SECTION (" ExportModel" ) {
269+ flatten->BatchInsertVector (vectors.data (), count);
270+ auto other = FlattenInterface::MakeInstance (param, common_param);
271+ other->Train (vectors.data (), count);
272+ flatten->ExportModel (other);
273+ }
274+
275+ SECTION (" MergeOther" ) {
276+ flatten->BatchInsertVector (vectors.data (), count / 2 );
277+ auto other_param = std::make_shared<FlattenDataCellParameter>();
278+ other_param->FromJson (param_json);
279+ auto other = FlattenInterface::MakeInstance (other_param, common_param);
280+ other->Train (vectors.data (), count);
281+ other->BatchInsertVector (vectors.data () + (count / 2 ) * dim, count / 2 );
282+ flatten->MergeOther (other, count / 2 );
283+ REQUIRE (flatten->TotalCount () == count);
284+ }
285+
286+ SECTION (" Metadata methods" ) {
287+ REQUIRE_FALSE (flatten->GetQuantizerName ().empty ());
288+ REQUIRE (flatten->GetMetricType () == MetricType::METRIC_TYPE_L2SQR);
289+ REQUIRE (flatten->InMemory () == true );
290+ auto memory = flatten->GetMemoryUsage ();
291+ REQUIRE (memory > 0 );
292+ }
293+
294+ SECTION (" QueryWithDistanceFilter" ) {
295+ flatten->BatchInsertVector (vectors.data (), count);
296+ auto query = fixtures::generate_vectors (1 , dim, 42 );
297+ auto computer = flatten->FactoryComputer (query.data ());
298+ std::vector<InnerIdType> idx (count);
299+ std::iota (idx.begin (), idx.end (), 0 );
300+ std::vector<float > dists (count);
301+ flatten->QueryWithDistanceFilter (
302+ dists.data (), computer, idx.data (), count, std::numeric_limits<float >::max ());
303+ for (InnerIdType i = 0 ; i < count; ++i) {
304+ REQUIRE (std::isfinite (dists[i]));
305+ }
306+ }
307+ }
308+
309+ TEST_CASE (" RaBitQSplitDataCell IP metric" , " [ut][RaBitQSplitDataCell]" ) {
310+ auto allocator = SafeAllocator::FactoryDefaultAllocator ();
311+ constexpr uint64_t dim = 64 ;
312+ constexpr InnerIdType count = 16 ;
313+ auto vectors = fixtures::generate_vectors (count, dim);
314+ auto queries = fixtures::generate_vectors (2 , dim, 42 );
315+
316+ constexpr const char * param_str = R"(
317+ {
318+ "codes_type": "rabitq_split",
319+ "io_params": {
320+ "type": "memory_io"
321+ },
322+ "quantization_params": {
323+ "type": "rabitq",
324+ "rabitq_version": "split_1bit_7bit",
325+ "rabitq_bits_per_dim_query": 32,
326+ "rabitq_bits_per_dim_base": 4
327+ }
328+ }
329+ )" ;
330+
331+ auto param_json = JsonType::Parse (param_str);
332+ auto param = std::make_shared<FlattenDataCellParameter>();
333+ param->FromJson (param_json);
334+
335+ IndexCommonParam common_param;
336+ common_param.allocator_ = allocator;
337+ common_param.dim_ = dim;
338+ common_param.metric_ = MetricType::METRIC_TYPE_IP;
339+
340+ auto flatten = FlattenInterface::MakeInstance (param, common_param);
341+ flatten->Train (vectors.data (), count);
342+ flatten->BatchInsertVector (vectors.data (), count);
343+
344+ std::vector<InnerIdType> idx (count);
345+ std::iota (idx.begin (), idx.end (), 0 );
346+ std::vector<float > dists (count);
347+ std::vector<float > lower_bounds (count);
348+
349+ auto computer = flatten->FactoryComputer (queries.data ());
350+ flatten->Query (dists.data (), computer, idx.data (), count);
351+ for (InnerIdType i = 0 ; i < count; ++i) {
352+ REQUIRE (std::isfinite (dists[i]));
353+ }
354+
355+ flatten->QueryWithDistanceLowerBound (
356+ dists.data (), lower_bounds.data (), computer, idx.data (), count);
357+ for (InnerIdType i = 0 ; i < count; ++i) {
358+ REQUIRE (std::isfinite (dists[i]));
359+ }
360+ }
0 commit comments