From 3394d98020764399110388894cce1b88dd3a10bb Mon Sep 17 00:00:00 2001 From: LHT129 Date: Fri, 5 Jun 2026 15:58:34 +0800 Subject: [PATCH] refactor(parameter): introduce macros to eliminate CheckCompatibility boilerplate Add PARAM_CAST_OR_RETURN, CHECK_FIELD_EQ, and CHECK_SUB_PARAM macros in src/utils/param_compat_macros.h to replace the repetitive dynamic_pointer_cast + logger::error + return false pattern that was duplicated across 23 parameter files. Signed-off-by: LHT129 Co-authored-by: opencode --- .../bruteforce/bruteforce_parameter.cpp | 13 ++-- src/algorithm/hgraph/hgraph_parameter.cpp | 59 ++++++------------ src/algorithm/inner_index_parameter.cpp | 37 +++-------- src/algorithm/ivf/gno_imi_parameter.cpp | 28 ++------- src/algorithm/ivf/ivf_parameter.cpp | 28 ++------- .../ivf/ivf_partition_strategy_parameter.cpp | 23 ++----- src/algorithm/pyramid/pyramid_zparameters.cpp | 57 ++++------------- src/algorithm/sindi/sindi_parameter.cpp | 34 +++------- src/algorithm/warp/warp_parameter.cpp | 10 +-- ...attribute_inverted_interface_parameter.cpp | 9 ++- src/datacell/bucket_datacell_parameter.cpp | 37 ++--------- .../compressed_graph_datacell_parameter.h | 37 ++--------- src/datacell/flatten_datacell_parameter.cpp | 18 ++---- src/datacell/graph_datacell_parameter.cpp | 49 +++------------ .../multi_vector_datacell_parameter.h | 10 +-- .../sparse_graph_datacell_parameter.cpp | 52 +++------------- .../sparse_vector_datacell_parameter.h | 12 ++-- .../vector_transformer_parameter.cpp | 21 ++----- .../pq_fastscan_quantizer_parameter.cpp | 19 +----- .../product_quantizer_parameter.cpp | 28 ++------- .../rabitq_quantizer_parameter.cpp | 62 +++---------------- .../sq4_uniform_quantizer_parameter.cpp | 20 +----- .../transform_quantizer_parameter.cpp | 21 ++----- src/utils/param_compat_macros.h | 45 ++++++++++++++ 24 files changed, 180 insertions(+), 549 deletions(-) create mode 100644 src/utils/param_compat_macros.h diff --git a/src/algorithm/bruteforce/bruteforce_parameter.cpp b/src/algorithm/bruteforce/bruteforce_parameter.cpp index fd77538548..ce6aaf8ced 100644 --- a/src/algorithm/bruteforce/bruteforce_parameter.cpp +++ b/src/algorithm/bruteforce/bruteforce_parameter.cpp @@ -18,8 +18,8 @@ #include #include "datacell/flatten_datacell_parameter.h" -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "vsag/constants.h" namespace vsag { @@ -49,13 +49,8 @@ BruteForceParameter::CheckCompatibility(const ParamPtr& other) const { if (not InnerIndexParameter::CheckCompatibility(other)) { return false; } - auto brute_force_param = std::dynamic_pointer_cast(other); - if (not brute_force_param) { - logger::error( - "BruteForceParameter::CheckCompatibility: " - "other parameter is not a BruteForceParameter"); - return false; - } - return this->base_codes_param->CheckCompatibility(brute_force_param->base_codes_param); + PARAM_CAST_OR_RETURN(BruteForceParameter, p, other); + CHECK_SUB_PARAM(*this, *p, base_codes_param); + return true; } } // namespace vsag diff --git a/src/algorithm/hgraph/hgraph_parameter.cpp b/src/algorithm/hgraph/hgraph_parameter.cpp index 0437833301..b0e35e31c7 100644 --- a/src/algorithm/hgraph/hgraph_parameter.cpp +++ b/src/algorithm/hgraph/hgraph_parameter.cpp @@ -23,6 +23,7 @@ #include "datacell/sparse_vector_datacell_parameter.h" #include "impl/odescent/odescent_graph_parameter.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "vsag/constants.h" namespace vsag { @@ -158,55 +159,31 @@ HGraphParameter::ToJson() const { bool HGraphParameter::CheckCompatibility(const ParamPtr& other) const { - auto hgraph_param = std::dynamic_pointer_cast(other); - if (hgraph_param == nullptr) { - logger::error("HGraphParameter::CheckCompatibility: other is not HGraphParameter"); - return false; - } + PARAM_CAST_OR_RETURN(HGraphParameter, p, other); auto have_reorder = this->use_reorder && not this->ignore_reorder; - auto have_reorder_other = hgraph_param->use_reorder && not hgraph_param->ignore_reorder; + auto have_reorder_other = p->use_reorder && not p->ignore_reorder; if (have_reorder != have_reorder_other) { logger::error( "HGraphParameter::CheckCompatibility: use_reorder and ignore_reorder must be the same"); return false; } - if (not this->base_codes_param->CheckCompatibility(hgraph_param->base_codes_param)) { - logger::error("HGraphParameter::CheckCompatibility: base_codes_param is not compatible"); - return false; - } - if (have_reorder && this->reorder_source != hgraph_param->reorder_source) { - logger::error("HGraphParameter::CheckCompatibility: reorder_source is not compatible"); - return false; - } - if (have_reorder && this->reorder_source != HGRAPH_REORDER_SOURCE_BASE) { - if (not this->precise_codes_param || - not this->precise_codes_param->CheckCompatibility(hgraph_param->precise_codes_param)) { - logger::error( - "HGraphParameter::CheckCompatibility: precise_codes_param is not compatible"); - return false; + CHECK_SUB_PARAM(*this, *p, base_codes_param); + if (have_reorder) { + CHECK_FIELD_EQ(*this, *p, reorder_source); + if (this->reorder_source != HGRAPH_REORDER_SOURCE_BASE) { + if (not this->precise_codes_param || + not this->precise_codes_param->CheckCompatibility(p->precise_codes_param)) { + logger::error( + "HGraphParameter::CheckCompatibility: precise_codes_param is not compatible"); + return false; + } } } - if (not this->bottom_graph_param->CheckCompatibility(hgraph_param->bottom_graph_param)) { - logger::error("HGraphParameter::CheckCompatibility: bottom_graph_param is not compatible"); - return false; - } - if (use_attribute_filter != hgraph_param->use_attribute_filter) { - logger::error("HGraphParameter::CheckCompatibility: use_attribute_filter must be the same"); - return false; - } - if (support_duplicate != hgraph_param->support_duplicate) { - logger::error("HGraphParameter::CheckCompatibility: support_duplicate must be the same"); - return false; - } - if (duplicate_distance_threshold != hgraph_param->duplicate_distance_threshold) { - logger::error( - "HGraphParameter::CheckCompatibility: duplicate_distance_threshold must be the same"); - return false; - } - if (support_force_remove != hgraph_param->support_force_remove) { - logger::error("HGraphParameter::CheckCompatibility: support_force_remove must be the same"); - return false; - } + CHECK_SUB_PARAM(*this, *p, bottom_graph_param); + CHECK_FIELD_EQ(*this, *p, use_attribute_filter); + CHECK_FIELD_EQ(*this, *p, support_duplicate); + CHECK_FIELD_EQ(*this, *p, duplicate_distance_threshold); + CHECK_FIELD_EQ(*this, *p, support_force_remove); return true; } diff --git a/src/algorithm/inner_index_parameter.cpp b/src/algorithm/inner_index_parameter.cpp index 89f060d3b4..6792263a02 100644 --- a/src/algorithm/inner_index_parameter.cpp +++ b/src/algorithm/inner_index_parameter.cpp @@ -21,6 +21,7 @@ #include "datacell/flatten_datacell_parameter.h" #include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "vsag/constants.h" namespace vsag { @@ -145,38 +146,14 @@ InnerIndexParameter::ToJson() const { } bool InnerIndexParameter::CheckCompatibility(const ParamPtr& other) const { - auto inner_index_param = std::dynamic_pointer_cast(other); - if (not inner_index_param) { - logger::error( - "InnerIndexParameter::CheckCompatibility: other parameter is not InnerIndexParameter"); - return false; - } - if (this->use_reorder != inner_index_param->use_reorder) { - logger::error("InnerIndexParameter::CheckCompatibility: use_reorder mismatch"); - return false; - } - if (this->reorder_source != inner_index_param->reorder_source) { - logger::error("InnerIndexParameter::CheckCompatibility: reorder_source mismatch"); - return false; - } + PARAM_CAST_OR_RETURN(InnerIndexParameter, p, other); + CHECK_FIELD_EQ(*this, *p, use_reorder); + CHECK_FIELD_EQ(*this, *p, reorder_source); if (this->use_reorder && this->reorder_source != HGRAPH_REORDER_SOURCE_BASE) { - if (not this->precise_codes_param->CheckCompatibility( - inner_index_param->precise_codes_param)) { - logger::error("InnerIndexParameter::CheckCompatibility: precise_codes_param mismatch"); - return false; - } + CHECK_SUB_PARAM(*this, *p, precise_codes_param); } - - if (this->use_attribute_filter != inner_index_param->use_attribute_filter) { - logger::error("InnerIndexParameter::CheckCompatibility: use_attribute_filter mismatch"); - return false; - } - - if (this->label_remap_type != inner_index_param->label_remap_type) { - logger::error("InnerIndexParameter::CheckCompatibility: label_remap_type mismatch"); - return false; - } - + CHECK_FIELD_EQ(*this, *p, use_attribute_filter); + CHECK_FIELD_EQ(*this, *p, label_remap_type); return true; } } // namespace vsag diff --git a/src/algorithm/ivf/gno_imi_parameter.cpp b/src/algorithm/ivf/gno_imi_parameter.cpp index 142ae5600e..cff2aa21ea 100644 --- a/src/algorithm/ivf/gno_imi_parameter.cpp +++ b/src/algorithm/ivf/gno_imi_parameter.cpp @@ -17,8 +17,8 @@ #include -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -49,29 +49,9 @@ GNOIMIParameter::ToJson() const { } bool GNOIMIParameter::CheckCompatibility(const ParamPtr& other) const { - auto gno_imi_param = std::dynamic_pointer_cast(other); - if (!gno_imi_param) { - logger::error( - "GNOIMIParameter::CheckCompatibility: " - "other parameter is not GNOIMIParameter"); - return false; - } - if (this->first_order_buckets_count != gno_imi_param->first_order_buckets_count) { - logger::error( - "GNOIMIParameter::CheckCompatibility: " - "first_order_buckets_count mismatch: {} != {}", - this->first_order_buckets_count, - gno_imi_param->first_order_buckets_count); - return false; - } - if (this->second_order_buckets_count != gno_imi_param->second_order_buckets_count) { - logger::error( - "GNOIMIParameter::CheckCompatibility: " - "second_order_buckets_count mismatch: {} != {}", - this->second_order_buckets_count, - gno_imi_param->second_order_buckets_count); - return false; - } + PARAM_CAST_OR_RETURN(GNOIMIParameter, p, other); + CHECK_FIELD_EQ(*this, *p, first_order_buckets_count); + CHECK_FIELD_EQ(*this, *p, second_order_buckets_count); return true; } } // namespace vsag diff --git a/src/algorithm/ivf/ivf_parameter.cpp b/src/algorithm/ivf/ivf_parameter.cpp index 6188be7f26..d6f1b83e5c 100644 --- a/src/algorithm/ivf/ivf_parameter.cpp +++ b/src/algorithm/ivf/ivf_parameter.cpp @@ -18,6 +18,7 @@ #include #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "vsag/constants.h" namespace vsag { @@ -64,29 +65,10 @@ IVFParameter::CheckCompatibility(const ParamPtr& other) const { if (not InnerIndexParameter::CheckCompatibility(other)) { return false; } - auto ivf_param = std::dynamic_pointer_cast(other); - if (not ivf_param) { - logger::error("IVFParameter::CheckCompatibility: other parameter is not IVFParameter"); - return false; - } - - if (this->buckets_per_data != ivf_param->buckets_per_data) { - logger::error("IVFParameter::CheckCompatibility: buckets_per_data mismatch"); - return false; - } - - if (not this->bucket_param->CheckCompatibility(ivf_param->bucket_param)) { - logger::error("IVFParameter::CheckCompatibility: bucket_param mismatch"); - return false; - } - - if (not this->ivf_partition_strategy_parameter->CheckCompatibility( - ivf_param->ivf_partition_strategy_parameter)) { - logger::error( - "IVFParameter::CheckCompatibility: ivf_partition_strategy_parameter " - "mismatch"); - return false; - } + PARAM_CAST_OR_RETURN(IVFParameter, p, other); + CHECK_FIELD_EQ(*this, *p, buckets_per_data); + CHECK_SUB_PARAM(*this, *p, bucket_param); + CHECK_SUB_PARAM(*this, *p, ivf_partition_strategy_parameter); return true; } diff --git a/src/algorithm/ivf/ivf_partition_strategy_parameter.cpp b/src/algorithm/ivf/ivf_partition_strategy_parameter.cpp index f6c6d5035b..294f2f168b 100644 --- a/src/algorithm/ivf/ivf_partition_strategy_parameter.cpp +++ b/src/algorithm/ivf/ivf_partition_strategy_parameter.cpp @@ -19,8 +19,8 @@ #include -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "vsag/constants.h" namespace vsag { @@ -75,23 +75,10 @@ IVFPartitionStrategyParameters::ToJson() const { bool IVFPartitionStrategyParameters::CheckCompatibility(const ParamPtr& other) const { - auto ivf_partition_param = std::dynamic_pointer_cast(other); - if (not ivf_partition_param) { - logger::error( - "IVFPartitionStrategyParameters::CheckCompatibility: other parameter is not " - "IVFPartitionStrategyParameters"); - return false; - } - if (partition_strategy_type != ivf_partition_param->partition_strategy_type) { - std::string message = fmt::format( - "IVFPartitionStrategyParameters::CheckCompatibility: partition strategy type mismatch, " - "this: {}, other: {}", - (int)partition_strategy_type, - (int)ivf_partition_param->partition_strategy_type); - logger::error(message); - return false; - } - return this->gnoimi_param->CheckCompatibility(ivf_partition_param->gnoimi_param); + PARAM_CAST_OR_RETURN(IVFPartitionStrategyParameters, p, other); + CHECK_FIELD_EQ(*this, *p, partition_strategy_type); + CHECK_SUB_PARAM(*this, *p, gnoimi_param); + return true; } } // namespace vsag diff --git a/src/algorithm/pyramid/pyramid_zparameters.cpp b/src/algorithm/pyramid/pyramid_zparameters.cpp index ce5be52aef..1cd0f59adb 100644 --- a/src/algorithm/pyramid/pyramid_zparameters.cpp +++ b/src/algorithm/pyramid/pyramid_zparameters.cpp @@ -20,6 +20,7 @@ #include "index/diskann_zparameters.h" #include "io/memory_io_parameter.h" #include "quantization/fp32_quantizer_parameter.h" +#include "utils/param_compat_macros.h" // NOLINTBEGIN(readability-simplify-boolean-expr) @@ -95,55 +96,21 @@ PyramidParameters::ToJson() const { bool PyramidParameters::CheckCompatibility(const ParamPtr& other) const { - auto pyramid_param = std::dynamic_pointer_cast(other); - if (not pyramid_param) { - logger::error( - "PyramidParameters::CheckCompatibility: other parameter is not PyramidParameters"); - return false; - } - if (not graph_param->CheckCompatibility(pyramid_param->graph_param)) { - logger::error("PyramidParameters::CheckCompatibility: graph parameters are not compatible"); - return false; - } - - if (not base_codes_param->CheckCompatibility(pyramid_param->base_codes_param)) { - logger::error( - "PyramidParameters::CheckCompatibility: flatten data cell parameters are not " - "compatible"); - return false; - } - if (no_build_levels.size() != pyramid_param->no_build_levels.size() || - not std::is_permutation(no_build_levels.begin(), - no_build_levels.end(), - pyramid_param->no_build_levels.begin())) { + PARAM_CAST_OR_RETURN(PyramidParameters, p, other); + CHECK_SUB_PARAM(*this, *p, graph_param); + CHECK_SUB_PARAM(*this, *p, base_codes_param); + if (no_build_levels.size() != p->no_build_levels.size() || + not std::is_permutation( + no_build_levels.begin(), no_build_levels.end(), p->no_build_levels.begin())) { logger::error("PyramidParameters::CheckCompatibility: no_build_levels are not compatible"); return false; } - - if (pyramid_param->use_reorder != this->use_reorder) { - logger::error( - "PyramidParameters::CheckCompatibility: use_reorder settings are not compatible"); - return false; - } - - if (this->use_reorder && - not precise_codes_param->CheckCompatibility(pyramid_param->precise_codes_param)) { - logger::error( - "PyramidParameters::CheckCompatibility: precise_codes_param are not compatible"); - return false; - } - - if (this->index_min_size != pyramid_param->index_min_size) { - logger::error("PyramidParameters::CheckCompatibility: index_min_size are not compatible"); - return false; - } - - if (this->support_duplicate != pyramid_param->support_duplicate) { - logger::error( - "PyramidParameters::CheckCompatibility: support_duplicate are not compatible"); - return false; + CHECK_FIELD_EQ(*this, *p, use_reorder); + if (this->use_reorder) { + CHECK_SUB_PARAM(*this, *p, precise_codes_param); } - + CHECK_FIELD_EQ(*this, *p, index_min_size); + CHECK_FIELD_EQ(*this, *p, support_duplicate); return true; } diff --git a/src/algorithm/sindi/sindi_parameter.cpp b/src/algorithm/sindi/sindi_parameter.cpp index 05fb99a53e..748ca1cf44 100644 --- a/src/algorithm/sindi/sindi_parameter.cpp +++ b/src/algorithm/sindi/sindi_parameter.cpp @@ -16,6 +16,7 @@ #include "sindi_parameter.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -97,31 +98,14 @@ SINDIParameter::ToJson() const { bool SINDIParameter::CheckCompatibility(const vsag::ParamPtr& other) const { - auto sindi_param = std::dynamic_pointer_cast(other); - if (sindi_param == nullptr) { - return false; - } - if (this->term_id_limit != sindi_param->term_id_limit) { - return false; - } - if (this->window_size != sindi_param->window_size) { - return false; - } - if (this->doc_prune_ratio != sindi_param->doc_prune_ratio) { - return false; - } - if (this->use_reorder != sindi_param->use_reorder) { - return false; - } - if (this->use_quantization != sindi_param->use_quantization) { - return false; - } - if (this->avg_doc_term_length != sindi_param->avg_doc_term_length) { - return false; - } - if (this->remap_term_ids != sindi_param->remap_term_ids) { - return false; - } + PARAM_CAST_OR_RETURN(SINDIParameter, p, other); + CHECK_FIELD_EQ(*this, *p, term_id_limit); + CHECK_FIELD_EQ(*this, *p, window_size); + CHECK_FIELD_EQ(*this, *p, doc_prune_ratio); + CHECK_FIELD_EQ(*this, *p, use_reorder); + CHECK_FIELD_EQ(*this, *p, use_quantization); + CHECK_FIELD_EQ(*this, *p, avg_doc_term_length); + CHECK_FIELD_EQ(*this, *p, remap_term_ids); return true; } diff --git a/src/algorithm/warp/warp_parameter.cpp b/src/algorithm/warp/warp_parameter.cpp index 2accb9dcd0..a94be16cbd 100644 --- a/src/algorithm/warp/warp_parameter.cpp +++ b/src/algorithm/warp/warp_parameter.cpp @@ -17,6 +17,8 @@ #include +#include "utils/param_compat_macros.h" + namespace vsag { void @@ -41,11 +43,9 @@ WarpParameter::CheckCompatibility(const vsag::ParamPtr& other) const { if (not InnerIndexParameter::CheckCompatibility(other)) { return false; } - auto other_param = std::dynamic_pointer_cast(other); - if (other_param == nullptr) { - return false; - } - return this->base_codes_param->CheckCompatibility(other_param->base_codes_param); + PARAM_CAST_OR_RETURN(WarpParameter, p, other); + CHECK_SUB_PARAM(*this, *p, base_codes_param); + return true; } } // namespace vsag diff --git a/src/datacell/attribute_inverted_interface_parameter.cpp b/src/datacell/attribute_inverted_interface_parameter.cpp index e5733e3314..891d6ef20a 100644 --- a/src/datacell/attribute_inverted_interface_parameter.cpp +++ b/src/datacell/attribute_inverted_interface_parameter.cpp @@ -16,6 +16,7 @@ #include "attribute_inverted_interface_parameter.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -34,11 +35,9 @@ AttributeInvertedInterfaceParameter::ToJson() const { } bool AttributeInvertedInterfaceParameter::CheckCompatibility(const ParamPtr& other) const { - auto other_param = std::dynamic_pointer_cast(other); - if (other_param == nullptr) { - return false; - } - return has_buckets_ == other_param->has_buckets_; + PARAM_CAST_OR_RETURN(AttributeInvertedInterfaceParameter, p, other); + CHECK_FIELD_EQ(*this, *p, has_buckets_); + return true; } } // namespace vsag diff --git a/src/datacell/bucket_datacell_parameter.cpp b/src/datacell/bucket_datacell_parameter.cpp index d4bdc5a8af..a007c2747b 100644 --- a/src/datacell/bucket_datacell_parameter.cpp +++ b/src/datacell/bucket_datacell_parameter.cpp @@ -17,8 +17,8 @@ #include -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { BucketDataCellParameter::BucketDataCellParameter() = default; @@ -55,37 +55,10 @@ BucketDataCellParameter::ToJson() const { } bool BucketDataCellParameter::CheckCompatibility(const ParamPtr& other) const { - auto bucket_param = std::dynamic_pointer_cast(other); - if (not bucket_param) { - logger::error( - "BucketDataCellParameter::CheckCompatibility: other parameter is not a " - "BucketDataCellParameter"); - return false; - } - - if (not this->quantizer_parameter->CheckCompatibility(bucket_param->quantizer_parameter)) { - logger::error( - "BucketDataCellParameter::CheckCompatibility: quantizer parameters are not compatible"); - return false; - } - - if (buckets_count != bucket_param->buckets_count) { - logger::error( - "BucketDataCellParameter::CheckCompatibility: buckets count is not compatible: {} != " - "{}", - buckets_count, - bucket_param->buckets_count); - return false; - } - - if (use_residual_ != bucket_param->use_residual_) { - logger::error( - "BucketDataCellParameter::CheckCompatibility: use residual is not compatible: {} != {}", - use_residual_, - bucket_param->use_residual_); - return false; - } - + PARAM_CAST_OR_RETURN(BucketDataCellParameter, p, other); + CHECK_SUB_PARAM(*this, *p, quantizer_parameter); + CHECK_FIELD_EQ(*this, *p, buckets_count); + CHECK_FIELD_EQ(*this, *p, use_residual_); return true; } } // namespace vsag diff --git a/src/datacell/compressed_graph_datacell_parameter.h b/src/datacell/compressed_graph_datacell_parameter.h index 811eac7f40..1057f19493 100644 --- a/src/datacell/compressed_graph_datacell_parameter.h +++ b/src/datacell/compressed_graph_datacell_parameter.h @@ -16,8 +16,8 @@ #pragma once #include "graph_interface_parameter.h" -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "utils/pointer_define.h" namespace vsag { @@ -53,37 +53,10 @@ class CompressedGraphDatacellParameter : public GraphInterfaceParameter { bool CheckCompatibility(const vsag::ParamPtr& other) const override { - auto graph_param = std::dynamic_pointer_cast(other); - if (not graph_param) { - logger::error( - "CompressedGraphDatacellParameter::CheckCompatibility: other parameter " - "is not a CompressedGraphDatacellParameter"); - return false; - } - if (max_degree_ != graph_param->max_degree_) { - logger::error( - "CompressedGraphDatacellParameter::CheckCompatibility: max_degree_ " - "mismatch: {} vs {}", - max_degree_, - graph_param->max_degree_); - return false; - } - if (support_duplicate_ != graph_param->support_duplicate_) { - logger::error( - "CompressedGraphDatacellParameter::CheckCompatibility: " - "support_duplicate_ mismatch: {} vs {}", - support_duplicate_, - graph_param->support_duplicate_); - return false; - } - if (use_reverse_edges_ != graph_param->use_reverse_edges_) { - logger::error( - "CompressedGraphDatacellParameter::CheckCompatibility: " - "use_reverse_edges_ mismatch: {} vs {}", - use_reverse_edges_, - graph_param->use_reverse_edges_); - return false; - } + PARAM_CAST_OR_RETURN(CompressedGraphDatacellParameter, p, other); + CHECK_FIELD_EQ(*this, *p, max_degree_); + CHECK_FIELD_EQ(*this, *p, support_duplicate_); + CHECK_FIELD_EQ(*this, *p, use_reverse_edges_); return true; } }; diff --git a/src/datacell/flatten_datacell_parameter.cpp b/src/datacell/flatten_datacell_parameter.cpp index 009939948b..a647b228ec 100644 --- a/src/datacell/flatten_datacell_parameter.cpp +++ b/src/datacell/flatten_datacell_parameter.cpp @@ -17,8 +17,8 @@ #include -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { FlattenDataCellParameter::FlattenDataCellParameter() @@ -53,17 +53,9 @@ FlattenDataCellParameter::ToJson() const { } bool FlattenDataCellParameter::CheckCompatibility(const ParamPtr& other) const { - auto flatten_other = std::dynamic_pointer_cast(other); - if (not flatten_other) { - logger::error( - "FlattenDataCellParameter::CheckCompatibility: " - "other parameter is not FlattenDataCellParameter"); - return false; - } - if (this->name != flatten_other->name) { - logger::error("FlattenDataCellParameter::CheckCompatibility: codes_type mismatch"); - return false; - } - return this->quantizer_parameter->CheckCompatibility(flatten_other->quantizer_parameter); + PARAM_CAST_OR_RETURN(FlattenDataCellParameter, p, other); + CHECK_FIELD_EQ(*this, *p, name); + CHECK_SUB_PARAM(*this, *p, quantizer_parameter); + return true; } } // namespace vsag diff --git a/src/datacell/graph_datacell_parameter.cpp b/src/datacell/graph_datacell_parameter.cpp index 901cf2de4e..40dfbc7617 100644 --- a/src/datacell/graph_datacell_parameter.cpp +++ b/src/datacell/graph_datacell_parameter.cpp @@ -17,8 +17,8 @@ #include -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "vsag/constants.h" namespace vsag { @@ -61,47 +61,12 @@ GraphDataCellParameter::ToJson() const { } bool GraphDataCellParameter::CheckCompatibility(const ParamPtr& other) const { - auto graph_param = std::dynamic_pointer_cast(other); - if (not graph_param) { - logger::error( - "GraphDataCellParameter::CheckCompatibility: other parameter is not a " - "GraphDataCellParameter"); - return false; - } - if (max_degree_ != graph_param->max_degree_) { - logger::error("GraphDataCellParameter::CheckCompatibility: max_degree_ mismatch: {} vs {}", - max_degree_, - graph_param->max_degree_); - return false; - } - if (support_remove_ != graph_param->support_remove_) { - logger::error( - "GraphDataCellParameter::CheckCompatibility: support_remove_ mismatch: {} vs {}", - support_remove_, - graph_param->support_remove_); - return false; - } - if (remove_flag_bit_ != graph_param->remove_flag_bit_) { - logger::error( - "GraphDataCellParameter::CheckCompatibility: remove_flag_bit_ mismatch: {} vs {}", - remove_flag_bit_, - graph_param->remove_flag_bit_); - return false; - } - if (use_reverse_edges_ != graph_param->use_reverse_edges_) { - logger::error( - "GraphDataCellParameter::CheckCompatibility: use_reverse_edges_ mismatch: {} vs {}", - use_reverse_edges_, - graph_param->use_reverse_edges_); - return false; - } - if (support_duplicate_ != graph_param->support_duplicate_) { - logger::error( - "GraphDataCellParameter::CheckCompatibility: support_duplicate_ mismatch: {} vs {}", - support_duplicate_, - graph_param->support_duplicate_); - return false; - } + PARAM_CAST_OR_RETURN(GraphDataCellParameter, p, other); + CHECK_FIELD_EQ(*this, *p, max_degree_); + CHECK_FIELD_EQ(*this, *p, support_remove_); + CHECK_FIELD_EQ(*this, *p, remove_flag_bit_); + CHECK_FIELD_EQ(*this, *p, use_reverse_edges_); + CHECK_FIELD_EQ(*this, *p, support_duplicate_); return true; } diff --git a/src/datacell/multi_vector_datacell_parameter.h b/src/datacell/multi_vector_datacell_parameter.h index a3cf9347e6..bf69e05c3c 100644 --- a/src/datacell/multi_vector_datacell_parameter.h +++ b/src/datacell/multi_vector_datacell_parameter.h @@ -18,9 +18,9 @@ #include #include "flatten_interface_parameter.h" -#include "impl/logger/logger.h" #include "inner_string_params.h" #include "quantization/fp32_quantizer_parameter.h" +#include "utils/param_compat_macros.h" #include "utils/pointer_define.h" namespace vsag { @@ -50,13 +50,7 @@ class MultiVectorDataCellParameter : public FlattenInterfaceParameter { bool CheckCompatibility(const vsag::ParamPtr& other) const override { - auto mv_param = std::dynamic_pointer_cast(other); - if (not mv_param) { - logger::error( - "MultiVectorDataCellParameter::CheckCompatibility: " - "other parameter is not MultiVectorDataCellParameter"); - return false; - } + PARAM_CAST_OR_RETURN(MultiVectorDataCellParameter, p, other); return true; } }; diff --git a/src/datacell/sparse_graph_datacell_parameter.cpp b/src/datacell/sparse_graph_datacell_parameter.cpp index 849299904d..a58fa6afde 100644 --- a/src/datacell/sparse_graph_datacell_parameter.cpp +++ b/src/datacell/sparse_graph_datacell_parameter.cpp @@ -15,7 +15,7 @@ #include "sparse_graph_datacell_parameter.h" -#include "impl/logger/logger.h" +#include "utils/param_compat_macros.h" namespace vsag { SparseGraphDatacellParameter::SparseGraphDatacellParameter() @@ -54,50 +54,12 @@ SparseGraphDatacellParameter::ToJson() const { bool SparseGraphDatacellParameter::CheckCompatibility(const ParamPtr& other) const { - auto graph_param = std::dynamic_pointer_cast(other); - if (not graph_param) { - logger::error( - "SparseGraphDatacellParameter::CheckCompatibility: other parameter is not a " - "SparseGraphDatacellParameter"); - return false; - } - if (max_degree_ != graph_param->max_degree_) { - logger::error( - "SparseGraphDatacellParameter::CheckCompatibility: max_degree_ mismatch: {} vs {}", - max_degree_, - graph_param->max_degree_); - return false; - } - if (support_delete_ != graph_param->support_delete_) { - logger::error( - "SparseGraphDatacellParameter::CheckCompatibility: support_delete_ mismatch: {} vs {}", - support_delete_, - graph_param->support_delete_); - return false; - } - if (remove_flag_bit_ != graph_param->remove_flag_bit_) { - logger::error( - "SparseGraphDatacellParameter::CheckCompatibility: remove_flag_bit_ mismatch: {} vs {}", - remove_flag_bit_, - graph_param->remove_flag_bit_); - return false; - } - if (support_duplicate_ != graph_param->support_duplicate_) { - logger::error( - "SparseGraphDatacellParameter::CheckCompatibility: support_duplicate_ mismatch: {} vs " - "{}", - support_duplicate_, - graph_param->support_duplicate_); - return false; - } - if (use_reverse_edges_ != graph_param->use_reverse_edges_) { - logger::error( - "SparseGraphDatacellParameter::CheckCompatibility: use_reverse_edges_ mismatch: {} " - "vs {}", - use_reverse_edges_, - graph_param->use_reverse_edges_); - return false; - } + PARAM_CAST_OR_RETURN(SparseGraphDatacellParameter, p, other); + CHECK_FIELD_EQ(*this, *p, max_degree_); + CHECK_FIELD_EQ(*this, *p, support_delete_); + CHECK_FIELD_EQ(*this, *p, remove_flag_bit_); + CHECK_FIELD_EQ(*this, *p, support_duplicate_); + CHECK_FIELD_EQ(*this, *p, use_reverse_edges_); return true; } } // namespace vsag diff --git a/src/datacell/sparse_vector_datacell_parameter.h b/src/datacell/sparse_vector_datacell_parameter.h index 35ccc8d09d..5925933a88 100644 --- a/src/datacell/sparse_vector_datacell_parameter.h +++ b/src/datacell/sparse_vector_datacell_parameter.h @@ -19,6 +19,7 @@ #include "flatten_interface.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" #include "utils/pointer_define.h" namespace vsag { @@ -55,14 +56,9 @@ class SparseVectorDataCellParameter : public FlattenInterfaceParameter { bool CheckCompatibility(const vsag::ParamPtr& other) const override { - auto sparse_param = std::dynamic_pointer_cast(other); - if (not sparse_param) { - logger::error( - "SparseVectorDataCellParameter::CheckCompatibility: " - "other parameter is not SparseVectorDataCellParameter"); - return false; - } - return this->quantizer_parameter->CheckCompatibility(sparse_param->quantizer_parameter); + PARAM_CAST_OR_RETURN(SparseVectorDataCellParameter, p, other); + CHECK_SUB_PARAM(*this, *p, quantizer_parameter); + return true; } }; } // namespace vsag diff --git a/src/impl/transform/vector_transformer_parameter.cpp b/src/impl/transform/vector_transformer_parameter.cpp index 39428e20d4..90d10fdce6 100644 --- a/src/impl/transform/vector_transformer_parameter.cpp +++ b/src/impl/transform/vector_transformer_parameter.cpp @@ -16,6 +16,7 @@ #include "vector_transformer_parameter.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -45,22 +46,10 @@ VectorTransformerParameter::ToJson() const { bool VectorTransformerParameter::CheckCompatibility(const ParamPtr& other) const { - auto param = std::dynamic_pointer_cast(other); - if (not param) { - logger::error( - "VectorTransformerParameter::CheckCompatibility: other parameter is not a " - "VectorTransformerParameter"); - return false; - } - if (pca_dim_ != param->pca_dim_) { - return false; - } - if (input_dim_ != param->input_dim_) { - return false; - } - if (mrle_dim_ != param->mrle_dim_) { - return false; - } + PARAM_CAST_OR_RETURN(VectorTransformerParameter, p, other); + CHECK_FIELD_EQ(*this, *p, pca_dim_); + CHECK_FIELD_EQ(*this, *p, input_dim_); + CHECK_FIELD_EQ(*this, *p, mrle_dim_); return true; } diff --git a/src/quantization/product_quantization/pq_fastscan_quantizer_parameter.cpp b/src/quantization/product_quantization/pq_fastscan_quantizer_parameter.cpp index 5d61dda850..4e917258c1 100644 --- a/src/quantization/product_quantization/pq_fastscan_quantizer_parameter.cpp +++ b/src/quantization/product_quantization/pq_fastscan_quantizer_parameter.cpp @@ -15,8 +15,8 @@ #include "pq_fastscan_quantizer_parameter.h" -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -42,21 +42,8 @@ PQFastScanQuantizerParameter::ToJson() const { bool PQFastScanQuantizerParameter::CheckCompatibility(const ParamPtr& other) const { - auto pq_fast_param = std::dynamic_pointer_cast(other); - if (not pq_fast_param) { - logger::error( - "PQFastScanQuantizerParameter::CheckCompatibility: " - "other is not PQFastScanQuantizerParameter"); - return false; - } - if (this->pq_dim_ != pq_fast_param->pq_dim_) { - logger::error( - "PQFastScanQuantizerParameter::CheckCompatibility: " - "pq_dim mismatch, this: {}, other: {}", - this->pq_dim_, - pq_fast_param->pq_dim_); - return false; - } + PARAM_CAST_OR_RETURN(PQFastScanQuantizerParameter, p, other); + CHECK_FIELD_EQ(*this, *p, pq_dim_); return true; } } // namespace vsag diff --git a/src/quantization/product_quantization/product_quantizer_parameter.cpp b/src/quantization/product_quantization/product_quantizer_parameter.cpp index 82bc68d01b..e4a4ae3f34 100644 --- a/src/quantization/product_quantization/product_quantizer_parameter.cpp +++ b/src/quantization/product_quantization/product_quantizer_parameter.cpp @@ -15,8 +15,8 @@ #include "product_quantizer_parameter.h" -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -48,29 +48,9 @@ ProductQuantizerParameter::ToJson() const { bool ProductQuantizerParameter::CheckCompatibility(const ParamPtr& other) const { - auto pq_other = std::dynamic_pointer_cast(other); - if (not pq_other) { - logger::error( - "ProductQuantizerParameter::CheckCompatibility: " - "other parameter is not a ProductQuantizerParameter"); - return false; - } - if (this->pq_dim_ != pq_other->pq_dim_) { - logger::error( - "ProductQuantizerParameter::CheckCompatibility: " - "pq_dim mismatch: {} vs {}", - this->pq_dim_, - pq_other->pq_dim_); - return false; - } - if (this->pq_bits_ != pq_other->pq_bits_) { - logger::error( - "ProductQuantizerParameter::CheckCompatibility: " - "pq_bits mismatch: {} vs {}", - this->pq_bits_, - pq_other->pq_bits_); - return false; - } + PARAM_CAST_OR_RETURN(ProductQuantizerParameter, p, other); + CHECK_FIELD_EQ(*this, *p, pq_dim_); + CHECK_FIELD_EQ(*this, *p, pq_bits_); return true; } } // namespace vsag diff --git a/src/quantization/rabitq_quantization/rabitq_quantizer_parameter.cpp b/src/quantization/rabitq_quantization/rabitq_quantizer_parameter.cpp index 1732e943a6..74628d65a1 100644 --- a/src/quantization/rabitq_quantization/rabitq_quantizer_parameter.cpp +++ b/src/quantization/rabitq_quantization/rabitq_quantizer_parameter.cpp @@ -17,8 +17,8 @@ #include -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -93,59 +93,13 @@ RaBitQuantizerParameter::ToJson() const { bool RaBitQuantizerParameter::CheckCompatibility(const ParamPtr& other) const { - auto rabitq_param = std::dynamic_pointer_cast(other); - if (not rabitq_param) { - logger::error( - "RaBitQuantizerParameter::CheckCompatibility: other parameter is not a " - "RaBitQuantizerParameter"); - return false; - } - if (this->pca_dim_ != rabitq_param->pca_dim_) { - logger::error( - "RaBitQuantizerParameter::CheckCompatibility: PCA dimensions do not match: {} vs {}", - this->pca_dim_, - rabitq_param->pca_dim_); - return false; - } - if (this->num_bits_per_dim_query_ != rabitq_param->num_bits_per_dim_query_) { - logger::error( - "RaBitQuantizerParameter::CheckCompatibility: Number of bits per dimension query do " - "not match: {} vs {}", - this->num_bits_per_dim_query_, - rabitq_param->num_bits_per_dim_query_); - return false; - } - if (this->num_bits_per_dim_base_ != rabitq_param->num_bits_per_dim_base_) { - logger::error( - "RaBitQuantizerParameter::CheckCompatibility: Number of bits per dimension base do " - "not match: {} vs {}", - this->num_bits_per_dim_base_, - rabitq_param->num_bits_per_dim_base_); - return false; - } - if (this->rabitq_version_ != rabitq_param->rabitq_version_) { - logger::error( - "RaBitQuantizerParameter::CheckCompatibility: RabitQ version does not match: {} vs {}", - this->rabitq_version_, - rabitq_param->rabitq_version_); - return false; - } - if (this->rabitq_error_rate_ != rabitq_param->rabitq_error_rate_) { - logger::error( - "RaBitQuantizerParameter::CheckCompatibility: RabitQ error rate does not match: {} " - "vs {}", - this->rabitq_error_rate_, - rabitq_param->rabitq_error_rate_); - return false; - } - if (this->use_fht_ != rabitq_param->use_fht_) { - logger::error( - "RaBitQuantizerParameter::CheckCompatibility: Use FHT flag does not match: {} vs {}", - this->use_fht_, - rabitq_param->use_fht_); - return false; - } - + PARAM_CAST_OR_RETURN(RaBitQuantizerParameter, p, other); + CHECK_FIELD_EQ(*this, *p, pca_dim_); + CHECK_FIELD_EQ(*this, *p, num_bits_per_dim_query_); + CHECK_FIELD_EQ(*this, *p, num_bits_per_dim_base_); + CHECK_FIELD_EQ(*this, *p, rabitq_version_); + CHECK_FIELD_EQ(*this, *p, rabitq_error_rate_); + CHECK_FIELD_EQ(*this, *p, use_fht_); return true; } } // namespace vsag diff --git a/src/quantization/scalar_quantization/sq4_uniform_quantizer_parameter.cpp b/src/quantization/scalar_quantization/sq4_uniform_quantizer_parameter.cpp index e5f57dfb9d..0207ddf88d 100644 --- a/src/quantization/scalar_quantization/sq4_uniform_quantizer_parameter.cpp +++ b/src/quantization/scalar_quantization/sq4_uniform_quantizer_parameter.cpp @@ -15,8 +15,8 @@ #include "sq4_uniform_quantizer_parameter.h" -#include "impl/logger/logger.h" #include "inner_string_params.h" +#include "utils/param_compat_macros.h" namespace vsag { SQ4UniformQuantizerParameter::SQ4UniformQuantizerParameter() @@ -40,22 +40,8 @@ SQ4UniformQuantizerParameter::ToJson() const { } bool SQ4UniformQuantizerParameter::CheckCompatibility(const ParamPtr& other) const { - auto other_sq4_uniform_quantizer_parameter = - std::dynamic_pointer_cast(other); - if (not other_sq4_uniform_quantizer_parameter) { - logger::error( - "SQ4UniformQuantizerParameter::CheckCompatibility: " - "other is not SQ4UniformQuantizerParameter"); - return false; - } - if (this->trunc_rate_ != other_sq4_uniform_quantizer_parameter->trunc_rate_) { - logger::error( - "SQ4UniformQuantizerParameter::CheckCompatibility: " - "trunc_rate mismatch: {} vs {}", - this->trunc_rate_, - other_sq4_uniform_quantizer_parameter->trunc_rate_); - return false; - } + PARAM_CAST_OR_RETURN(SQ4UniformQuantizerParameter, p, other); + CHECK_FIELD_EQ(*this, *p, trunc_rate_); return true; } } // namespace vsag diff --git a/src/quantization/transform_quantization/transform_quantizer_parameter.cpp b/src/quantization/transform_quantization/transform_quantizer_parameter.cpp index 8dc68ec011..e9d1b533dc 100644 --- a/src/quantization/transform_quantization/transform_quantizer_parameter.cpp +++ b/src/quantization/transform_quantization/transform_quantizer_parameter.cpp @@ -15,7 +15,7 @@ #include "transform_quantizer_parameter.h" -#include "impl/logger/logger.h" +#include "utils/param_compat_macros.h" namespace vsag { @@ -97,22 +97,9 @@ TransformQuantizerParameter::ToJson() const { bool TransformQuantizerParameter::CheckCompatibility(const ParamPtr& other) const { - auto tq_param = std::dynamic_pointer_cast(other); - if (not tq_param) { - logger::error( - "TransformQuantizerParameter::CheckCompatibility: other parameter is not a " - "TransformQuantizerParameter"); - return false; - } - if (tq_param->tq_chain_.size() != this->tq_chain_.size()) { - return false; - } - for (auto i = 0; i < tq_param->tq_chain_.size(); i++) { - if (this->tq_chain_[i] != tq_param->tq_chain_[i]) { - return false; - } - } + PARAM_CAST_OR_RETURN(TransformQuantizerParameter, p, other); + CHECK_FIELD_EQ(*this, *p, tq_chain_); return this->base_quantizer_json_[TYPE_KEY].GetString() == - tq_param->base_quantizer_json_[TYPE_KEY].GetString(); + p->base_quantizer_json_[TYPE_KEY].GetString(); } } // namespace vsag diff --git a/src/utils/param_compat_macros.h b/src/utils/param_compat_macros.h new file mode 100644 index 0000000000..606acbda4c --- /dev/null +++ b/src/utils/param_compat_macros.h @@ -0,0 +1,45 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "impl/logger/logger.h" + +// Cast `other` to the target parameter type; log and return false on type mismatch. +#define PARAM_CAST_OR_RETURN(Type, var, other) \ + auto var = std::dynamic_pointer_cast((other)); \ + if (not var) { \ + logger::error(#Type "::CheckCompatibility: type mismatch"); \ + return false; \ + } + +// Compare a scalar field; log the field name and return false on mismatch. +#define CHECK_FIELD_EQ(self, other, field) \ + if ((self).field != (other).field) { \ + logger::error(#field " mismatch"); \ + return false; \ + } + +// Recursively check a sub-parameter; log the field name and return false on incompatibility. +// Safely handles null pointers: both null is compatible, one null is incompatible. +#define CHECK_SUB_PARAM(self, other, field) \ + if (((self).field == nullptr) != ((other).field == nullptr)) { \ + logger::error(#field " incompatible (null mismatch)"); \ + return false; \ + } \ + if ((self).field && not(self).field->CheckCompatibility((other).field)) { \ + logger::error(#field " incompatible"); \ + return false; \ + }