Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 138 additions & 68 deletions cpp/src/gandiva/precompiled/string_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

// String functions
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging_internal.h"
#include "arrow/util/value_parsing.h"

Expand Down Expand Up @@ -1924,9 +1925,19 @@ const char* quote_utf8(gdv_int64 context, const char* in, gdv_int32 in_len,
*out_len = 0;
return "";
}

gdv_int32 double_len = 0;
gdv_int32 alloc_len = 0;
if (ARROW_PREDICT_FALSE(
arrow::internal::MultiplyWithOverflow(in_len, 2, &double_len)) ||
ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow(double_len, 2, &alloc_len))) {
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
*out_len = 0;
return "";
}

// try to allocate double size output string (worst case)
auto out =
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, (in_len * 2) + 2));
auto out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_len));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
*out_len = 0;
Expand Down Expand Up @@ -2424,6 +2435,71 @@ const char* byte_substr_binary_int32_int32(gdv_int64 context, const char* text,
return ret;
}

// Bookkeeping shared by the concat_ws length helpers while the size of the
// output is being computed.
struct ConcatWsLengthState {
  // Running output length: word lengths are summed here first, and the
  // separator contribution is added once every word has been counted.
  gdv_int32 total_length{0};
  // How many valid (non-null) words have been accumulated so far.
  gdv_int32 valid_count{0};
};

// Common failure path for the concat_ws length helpers.
//
// Flags the result as invalid with zero length, records `message` as the
// context error, and always returns false so that callers can write
// `return concat_ws_length_error(...)` directly.
FORCE_INLINE
bool concat_ws_length_error(gdv_int64 context, const char* message, bool* out_valid,
                            gdv_int32* out_len) {
  *out_valid = false;
  *out_len = 0;
  gdv_fn_context_set_error_msg(context, message);
  return false;
}

// Adds one input word to the running length totals in `state`.
//
// Invalid (null) words are ignored and leave `state` untouched. For a valid
// word, a negative length or a sum that does not fit in gdv_int32 is reported
// through concat_ws_length_error(), in which case `state` is not modified.
//
// Returns true when the caller may continue, false when an error was reported.
FORCE_INLINE
bool concat_ws_accumulate_word_length(gdv_int64 context, ConcatWsLengthState* state,
                                      gdv_int32 word_len, bool word_validity,
                                      bool* out_valid, gdv_int32* out_len) {
  if (word_validity) {
    if (ARROW_PREDICT_FALSE(word_len < 0)) {
      return concat_ws_length_error(context, "Invalid (negative) data length", out_valid,
                                    out_len);
    }

    // Compute into a temporary so the state is only committed on success.
    gdv_int32 new_total = 0;
    const bool overflowed =
        arrow::internal::AddWithOverflow(state->total_length, word_len, &new_total);
    if (ARROW_PREDICT_FALSE(overflowed)) {
      return concat_ws_length_error(context, "Would overflow maximum output size",
                                    out_valid, out_len);
    }

    state->total_length = new_total;
    ++state->valid_count;
  }
  return true;
}

// Completes the length computation once every word has been accumulated.
//
// A negative `separator_len` is rejected even when no separator would be
// emitted. When more than one valid word was seen, the space for
// (valid_count - 1) separators is added to the running total, with overflow
// checks on both the multiplication and the addition. On success the final
// length is written to `*out_len`; `*out_valid` is left for the caller to set.
//
// Returns true on success, false when an error was reported.
FORCE_INLINE
bool concat_ws_finish_length(gdv_int64 context, ConcatWsLengthState* state,
                             gdv_int32 separator_len, bool* out_valid,
                             gdv_int32* out_len) {
  if (ARROW_PREDICT_FALSE(separator_len < 0)) {
    return concat_ws_length_error(context, "Invalid (negative) data length", out_valid,
                                  out_len);
  }

  // Separators only go between valid words, so there is one fewer than words.
  const gdv_int32 separator_count = state->valid_count - 1;
  if (separator_count > 0) {
    gdv_int32 separator_bytes = 0;
    if (ARROW_PREDICT_FALSE(arrow::internal::MultiplyWithOverflow(
            separator_len, separator_count, &separator_bytes))) {
      return concat_ws_length_error(context, "Would overflow maximum output size",
                                    out_valid, out_len);
    }

    gdv_int32 grand_total = 0;
    if (ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow(
            state->total_length, separator_bytes, &grand_total))) {
      return concat_ws_length_error(context, "Would overflow maximum output size",
                                    out_valid, out_len);
    }
    state->total_length = grand_total;
  }

  *out_len = state->total_length;
  return true;
}

FORCE_INLINE
void concat_word(char* out_buf, int* out_idx, const char* in_buf, int in_len,
bool in_validity, const char* separator, int separator_len,
Expand Down Expand Up @@ -2451,24 +2527,22 @@ const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
const char* word2, int32_t word2_len, bool word2_validity,
bool* out_valid, int32_t* out_len) {
*out_len = 0;
int numValidInput = 0;
// If separator is null, always return null
if (!separator_validity) {
*out_len = 0;
*out_valid = false;
return "";
}

if (word1_validity) {
*out_len += word1_len;
numValidInput++;
}
if (word2_validity) {
*out_len += word2_len;
numValidInput++;
ConcatWsLengthState state;
if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity,
out_valid, out_len) ||
!concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) {
return "";
}

*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
if (*out_len == 0) {
*out_valid = true;
return "";
Expand Down Expand Up @@ -2503,29 +2577,24 @@ const char* concat_ws_utf8_utf8_utf8(
const char* word2, int32_t word2_len, bool word2_validity, const char* word3,
int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len) {
*out_len = 0;
int numValidInput = 0;
// If separator is null, always return null
if (!separator_validity) {
*out_len = 0;
*out_valid = false;
return "";
}

if (word1_validity) {
*out_len += word1_len;
numValidInput++;
}
if (word2_validity) {
*out_len += word2_len;
numValidInput++;
}
if (word3_validity) {
*out_len += word3_len;
numValidInput++;
ConcatWsLengthState state;
if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word3_len, word3_validity,
out_valid, out_len) ||
!concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) {
return "";
}

*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);

if (*out_len == 0) {
*out_len = 0;
*out_valid = true;
Expand Down Expand Up @@ -2564,31 +2633,25 @@ const char* concat_ws_utf8_utf8_utf8_utf8(
int32_t word3_len, bool word3_validity, const char* word4, int32_t word4_len,
bool word4_validity, bool* out_valid, int32_t* out_len) {
*out_len = 0;
int numValidInput = 0;
// If separator is null, always return null
if (!separator_validity) {
*out_len = 0;
*out_valid = false;
return "";
}
if (word1_validity) {
*out_len += word1_len;
numValidInput++;
}
if (word2_validity) {
*out_len += word2_len;
numValidInput++;
}
if (word3_validity) {
*out_len += word3_len;
numValidInput++;
}
if (word4_validity) {
*out_len += word4_len;
numValidInput++;
}

*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
ConcatWsLengthState state;
if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word3_len, word3_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word4_len, word4_validity,
out_valid, out_len) ||
!concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) {
return "";
}

if (*out_len == 0) {
*out_len = 0;
Expand Down Expand Up @@ -2631,35 +2694,27 @@ const char* concat_ws_utf8_utf8_utf8_utf8_utf8(
bool word4_validity, const char* word5, int32_t word5_len, bool word5_validity,
bool* out_valid, int32_t* out_len) {
*out_len = 0;
int numValidInput = 0;
// If separator is null, always return null
if (!separator_validity) {
*out_len = 0;
*out_valid = false;
return "";
}
if (word1_validity) {
*out_len += word1_len;
numValidInput++;
}
if (word2_validity) {
*out_len += word2_len;
numValidInput++;
}
if (word3_validity) {
*out_len += word3_len;
numValidInput++;
}
if (word4_validity) {
*out_len += word4_len;
numValidInput++;
}
if (word5_validity) {
*out_len += word5_len;
numValidInput++;
}

*out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
ConcatWsLengthState state;
if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word3_len, word3_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word4_len, word4_validity,
out_valid, out_len) ||
!concat_ws_accumulate_word_length(context, &state, word5_len, word5_validity,
out_valid, out_len) ||
!concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) {
return "";
}

if (*out_len == 0) {
*out_len = 0;
Expand Down Expand Up @@ -2829,8 +2884,23 @@ const char* to_hex_binary(int64_t context, const char* text, int32_t text_len,
return "";
}

auto ret =
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, text_len * 2 + 1));
if (ARROW_PREDICT_FALSE(text_len < 0)) {
gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
*out_len = 0;
return "";
}

int32_t hex_len = 0;
int32_t alloc_len = 0;
if (ARROW_PREDICT_FALSE(
arrow::internal::MultiplyWithOverflow(text_len, 2, &hex_len)) ||
ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow(hex_len, 1, &alloc_len))) {
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
*out_len = 0;
return "";
}

auto ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_len));

if (ret == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
Expand All @@ -2839,7 +2909,7 @@ const char* to_hex_binary(int64_t context, const char* text, int32_t text_len,
}

uint32_t ret_index = 0;
uint32_t max_len = static_cast<uint32_t>(text_len) * 2;
uint32_t max_len = static_cast<uint32_t>(hex_len);
uint32_t max_char_to_write = 4;

for (gdv_int32 i = 0; i < text_len; i++) {
Expand Down
46 changes: 46 additions & 0 deletions cpp/src/gandiva/precompiled/string_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,20 @@ TEST(TestStringOps, TestQuote) {
out_str = quote_utf8(ctx_ptr, "'''''''''", 9, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "'\\'\\'\\'\\'\\'\\'\\'\\'\\''");
EXPECT_FALSE(ctx.has_error());

out_str = quote_utf8(ctx_ptr, "abc", std::numeric_limits<int32_t>::max() / 2 + 1,
&out_len);
EXPECT_STREQ(out_str, "");
EXPECT_EQ(out_len, 0);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

out_str =
quote_utf8(ctx_ptr, "abc", std::numeric_limits<int32_t>::max() / 2, &out_len);
EXPECT_STREQ(out_str, "");
EXPECT_EQ(out_len, 0);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();
}

TEST(TestStringOps, TestLtrim) {
Expand Down Expand Up @@ -2298,6 +2312,22 @@ TEST(TestStringOps, TestConcatWs) {
EXPECT_EQ(std::string(out, out_len), "hey");
EXPECT_EQ(out_result, true);

out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1, -1, true, word2,
word2_len, false, &out_result, &out_len);
EXPECT_STREQ(out, "");
EXPECT_EQ(out_len, 0);
EXPECT_EQ(out_result, false);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

out = concat_ws_utf8_utf8(ctx_ptr, separator, -1, true, word1, word1_len, true, word2,
word2_len, false, &out_result, &out_len);
EXPECT_STREQ(out, "");
EXPECT_EQ(out_len, 0);
EXPECT_EQ(out_result, false);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

separator = "#";
sep_len = static_cast<int32_t>(strlen(separator));
const char* word3 = "wow";
Expand All @@ -2309,6 +2339,15 @@ TEST(TestStringOps, TestConcatWs) {
EXPECT_EQ(std::string(out, out_len), "hey#hello#wow");
EXPECT_EQ(out_result, true);

out = concat_ws_utf8_utf8_utf8(
ctx_ptr, separator, std::numeric_limits<int32_t>::max() / 2 + 1, true, "", 0,
true, "", 0, true, "", 0, true, &out_result, &out_len);
EXPECT_STREQ(out, "");
EXPECT_EQ(out_len, 0);
EXPECT_EQ(out_result, false);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, "", 0, true, word2,
word2_len, false, word3, word3_len, true, &out_result,
&out_len);
Expand Down Expand Up @@ -2498,6 +2537,13 @@ TEST(TestStringOps, TestToHex) {
output = std::string(out_str, out_len);
EXPECT_EQ(out_len, 2 * in_len);
EXPECT_EQ(output, "090A090A090A090A0A0A092061206C657474405D6572");

out_str = to_hex_binary(ctx_ptr, "A", std::numeric_limits<int32_t>::max() / 2 + 1,
&out_len);
EXPECT_STREQ(out_str, "");
EXPECT_EQ(out_len, 0);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();
}

TEST(TestStringOps, TestToHexInt64) {
Expand Down