Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,10 @@ TransformQuantizer<QuantTmpl, metric>::ProcessQueryImpl(
const float* query, Computer<TransformQuantizer>& computer) const {
// 0. allocate
try {
computer.inner_computer_->buf_ =
reinterpret_cast<uint8_t*>(this->allocator_->Allocate(this->query_code_size_));
if (computer.inner_computer_->buf_ == nullptr) {
computer.inner_computer_->buf_ =
reinterpret_cast<uint8_t*>(this->allocator_->Allocate(this->query_code_size_));
}
Comment on lines +287 to +290

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While checking for nullptr before allocation prevents repeated allocation, we should also handle the case where the allocator returns nullptr instead of throwing std::bad_alloc (which is common for some custom allocators or when exceptions are disabled). Adding a nullptr check immediately after allocation ensures robust error handling and prevents potential null pointer dereferences later in ExecuteChainTransform.

        if (computer.inner_computer_->buf_ == nullptr) {
            computer.inner_computer_->buf_ =
                reinterpret_cast<uint8_t*>(this->allocator_->Allocate(this->query_code_size_));
            if (computer.inner_computer_->buf_ == nullptr) {
                throw VsagException(ErrorType::NO_ENOUGH_MEMORY, "alloc return nullptr when init computer buf");
            }
        }

} catch (const std::bad_alloc& e) {
computer.inner_computer_->buf_ = nullptr;
throw VsagException(ErrorType::NO_ENOUGH_MEMORY, "bad alloc when init computer buf");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,43 @@ TEST_CASE("TQ Serialize and Deserialize", "[ut][TransformQuantizer]") {
}
}
}

TEST_CASE("TQ Repeated SetQuery No Leak", "[ut][TransformQuantizer]") {
constexpr MetricType metric = MetricType::METRIC_TYPE_L2SQR;
uint64_t dim = 128;
uint64_t count = 200;

auto allocator = SafeAllocator::FactoryDefaultAllocator();
auto param = std::make_shared<TransformQuantizerParameter>();
static constexpr const char* param_template = R"(
{{
"tq_chain": "rom, fp32",
"pca_dim": {},
"mrle_dim": {}
}}
)";
auto param_str = fmt::format(param_template, dim, dim - 1);
auto param_json = vsag::JsonType::Parse(param_str);
param->FromJson(param_json);

IndexCommonParam common_param;
common_param.allocator_ = allocator;
common_param.dim_ = dim;
TransformQuantizer<FP32Quantizer<metric>, metric> quantizer(param, common_param);

auto vecs = fixtures::generate_vectors(count, dim);
quantizer.Train(vecs.data(), count);

auto computer = quantizer.FactoryComputer();
auto queries = fixtures::generate_vectors(10, dim, false, 42);

for (int i = 0; i < 10; ++i) {
computer->SetQuery(queries.data() + i * dim);
}

std::vector<uint8_t> codes(quantizer.GetCodeSize());
quantizer.EncodeOne(vecs.data(), codes.data());
float dist = 0;
quantizer.ComputeDist(*computer, codes.data(), &dist);
REQUIRE(dist >= 0);
}
Loading