diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index 035d3c8c62e1..e2ed48cd92b1 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -16,6 +16,7 @@ // under the License. // String functions +#include "arrow/util/int_util_overflow.h" #include "arrow/util/logging_internal.h" #include "arrow/util/value_parsing.h" @@ -1924,9 +1925,19 @@ const char* quote_utf8(gdv_int64 context, const char* in, gdv_int32 in_len, *out_len = 0; return ""; } + + gdv_int32 double_len = 0; + gdv_int32 alloc_len = 0; + if (ARROW_PREDICT_FALSE( + arrow::internal::MultiplyWithOverflow(in_len, 2, &double_len)) || + ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow(double_len, 2, &alloc_len))) { + gdv_fn_context_set_error_msg(context, "Would overflow maximum output size"); + *out_len = 0; + return ""; + } + // try to allocate double size output string (worst case) - auto out = - reinterpret_cast(gdv_fn_context_arena_malloc(context, (in_len * 2) + 2)); + auto out = reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_len)); if (out == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); *out_len = 0; @@ -2424,6 +2435,71 @@ const char* byte_substr_binary_int32_int32(gdv_int64 context, const char* text, return ret; } +struct ConcatWsLengthState { + gdv_int32 total_length = 0; + gdv_int32 valid_count = 0; +}; + +FORCE_INLINE +bool concat_ws_length_error(gdv_int64 context, const char* message, bool* out_valid, + gdv_int32* out_len) { + gdv_fn_context_set_error_msg(context, message); + *out_len = 0; + *out_valid = false; + return false; +} + +FORCE_INLINE +bool concat_ws_accumulate_word_length(gdv_int64 context, ConcatWsLengthState* state, + gdv_int32 word_len, bool word_validity, + bool* out_valid, gdv_int32* out_len) { + if (!word_validity) { + return true; + } + + if (ARROW_PREDICT_FALSE(word_len < 0)) { + return concat_ws_length_error(context, "Invalid (negative) data length", out_valid, + out_len); + } + + gdv_int32 total_length = 0; + if (ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow(state->total_length, word_len, + &total_length))) { + return concat_ws_length_error(context, "Would overflow maximum output size", + out_valid, out_len); + } + + state->total_length = total_length; + state->valid_count++; + return true; +} + +FORCE_INLINE +bool concat_ws_finish_length(gdv_int64 context, ConcatWsLengthState* state, + gdv_int32 separator_len, bool* out_valid, + gdv_int32* out_len) { + if (ARROW_PREDICT_FALSE(separator_len < 0)) { + return concat_ws_length_error(context, "Invalid (negative) data length", out_valid, + out_len); + } + + if (state->valid_count > 1) { + gdv_int32 separators_length = 0; + gdv_int32 total_length = 0; + if (ARROW_PREDICT_FALSE(arrow::internal::MultiplyWithOverflow( + separator_len, state->valid_count - 1, &separators_length)) || + ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow( + state->total_length, separators_length, &total_length))) { + return concat_ws_length_error(context, "Would overflow maximum output size", + out_valid, out_len); + } + state->total_length = total_length; + } + + *out_len = state->total_length; + return true; +} + FORCE_INLINE void concat_word(char* out_buf, int* out_idx, const char* in_buf, int in_len, bool in_validity, const char* separator, int separator_len, @@ -2451,7 +2527,6 @@ const char* concat_ws_utf8_utf8(int64_t context, const char* separator, const char* word2, int32_t word2_len, bool word2_validity, bool* out_valid, int32_t* out_len) { *out_len = 0; - int numValidInput = 0; // If separator is null, always return null if (!separator_validity) { *out_len = 0; @@ -2459,16 +2534,15 @@ const char* concat_ws_utf8_utf8(int64_t context, const char* separator, return ""; } - if (word1_validity) { - *out_len += word1_len; - numValidInput++; - } - if (word2_validity) { - *out_len += word2_len; - numValidInput++; + ConcatWsLengthState state; + if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity, + out_valid, out_len) || + !concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) { + return ""; } - *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0); if (*out_len == 0) { *out_valid = true; return ""; @@ -2503,7 +2577,6 @@ const char* concat_ws_utf8_utf8_utf8( const char* word2, int32_t word2_len, bool word2_validity, const char* word3, int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len) { *out_len = 0; - int numValidInput = 0; // If separator is null, always return null if (!separator_validity) { *out_len = 0; @@ -2511,21 +2584,17 @@ const char* concat_ws_utf8_utf8_utf8( return ""; } - if (word1_validity) { - *out_len += word1_len; - numValidInput++; - } - if (word2_validity) { - *out_len += word2_len; - numValidInput++; - } - if (word3_validity) { - *out_len += word3_len; - numValidInput++; + ConcatWsLengthState state; + if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word3_len, word3_validity, + out_valid, out_len) || + !concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) { + return ""; } - *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0); - if (*out_len == 0) { *out_len = 0; *out_valid = true; @@ -2564,31 +2633,25 @@ const char* concat_ws_utf8_utf8_utf8_utf8( int32_t word3_len, bool word3_validity, const char* word4, int32_t word4_len, bool word4_validity, bool* out_valid, int32_t* out_len) { *out_len = 0; - int numValidInput = 0; // If separator is null, always return null if (!separator_validity) { *out_len = 0; *out_valid = false; return ""; } - if (word1_validity) { - *out_len += word1_len; - numValidInput++; - } - if (word2_validity) { - *out_len += word2_len; - numValidInput++; - } - if (word3_validity) { - *out_len += word3_len; - numValidInput++; - } - if (word4_validity) { - *out_len += word4_len; - numValidInput++; - } - *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0); + ConcatWsLengthState state; + if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word3_len, word3_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word4_len, word4_validity, + out_valid, out_len) || + !concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) { + return ""; + } if (*out_len == 0) { *out_len = 0; @@ -2631,35 +2694,27 @@ const char* concat_ws_utf8_utf8_utf8_utf8_utf8( bool word4_validity, const char* word5, int32_t word5_len, bool word5_validity, bool* out_valid, int32_t* out_len) { *out_len = 0; - int numValidInput = 0; // If separator is null, always return null if (!separator_validity) { *out_len = 0; *out_valid = false; return ""; } - if (word1_validity) { - *out_len += word1_len; - numValidInput++; - } - if (word2_validity) { - *out_len += word2_len; - numValidInput++; - } - if (word3_validity) { - *out_len += word3_len; - numValidInput++; - } - if (word4_validity) { - *out_len += word4_len; - numValidInput++; - } - if (word5_validity) { - *out_len += word5_len; - numValidInput++; - } - *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0); + ConcatWsLengthState state; + if (!concat_ws_accumulate_word_length(context, &state, word1_len, word1_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word2_len, word2_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word3_len, word3_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word4_len, word4_validity, + out_valid, out_len) || + !concat_ws_accumulate_word_length(context, &state, word5_len, word5_validity, + out_valid, out_len) || + !concat_ws_finish_length(context, &state, separator_len, out_valid, out_len)) { + return ""; + } if (*out_len == 0) { *out_len = 0; @@ -2829,8 +2884,22 @@ const char* to_hex_binary(int64_t context, const char* text, int32_t text_len, return ""; } - auto ret = - reinterpret_cast(gdv_fn_context_arena_malloc(context, text_len * 2 + 1)); + if (ARROW_PREDICT_FALSE(text_len < 0)) { + gdv_fn_context_set_error_msg(context, "Invalid (negative) data length"); + *out_len = 0; + return ""; + } + + int32_t hex_len = 0; + int32_t alloc_len = 0; + if (ARROW_PREDICT_FALSE(arrow::internal::MultiplyWithOverflow(text_len, 2, &hex_len)) || + ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow(hex_len, 1, &alloc_len))) { + gdv_fn_context_set_error_msg(context, "Would overflow maximum output size"); + *out_len = 0; + return ""; + } + + auto ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, alloc_len)); if (ret == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); @@ -2839,7 +2908,7 @@ const char* to_hex_binary(int64_t context, const char* text, int32_t text_len, } uint32_t ret_index = 0; - uint32_t max_len = static_cast(text_len) * 2; + uint32_t max_len = static_cast(hex_len); uint32_t max_char_to_write = 4; for (gdv_int32 i = 0; i < text_len; i++) { diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index d57eb437530c..e0e990cfa15f 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -1165,6 +1165,21 @@ TEST(TestStringOps, TestQuote) { out_str = quote_utf8(ctx_ptr, "'''''''''", 9, &out_len); EXPECT_EQ(std::string(out_str, out_len), "'\\'\\'\\'\\'\\'\\'\\'\\'\\''"); EXPECT_FALSE(ctx.has_error()); + + out_str = + quote_utf8(ctx_ptr, "abc", std::numeric_limits::max() / 2 + 1, &out_len); + EXPECT_STREQ(out_str, ""); + EXPECT_EQ(out_len, 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); + ctx.Reset(); + + out_str = quote_utf8(ctx_ptr, "abc", std::numeric_limits::max() / 2, &out_len); + EXPECT_STREQ(out_str, ""); + EXPECT_EQ(out_len, 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); + ctx.Reset(); } TEST(TestStringOps, TestLtrim) { @@ -2298,6 +2313,22 @@ TEST(TestStringOps, TestConcatWs) { EXPECT_EQ(std::string(out, out_len), "hey"); EXPECT_EQ(out_result, true); + out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1, -1, true, word2, + word2_len, false, &out_result, &out_len); + EXPECT_STREQ(out, ""); + EXPECT_EQ(out_len, 0); + EXPECT_EQ(out_result, false); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length")); + ctx.Reset(); + + out = concat_ws_utf8_utf8(ctx_ptr, separator, -1, true, word1, word1_len, true, word2, + word2_len, false, &out_result, &out_len); + EXPECT_STREQ(out, ""); + EXPECT_EQ(out_len, 0); + EXPECT_EQ(out_result, false); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length")); + ctx.Reset(); + separator = "#"; sep_len = static_cast(strlen(separator)); const char* word3 = "wow"; @@ -2309,6 +2340,16 @@ TEST(TestStringOps, TestConcatWs) { EXPECT_EQ(std::string(out, out_len), "hey#hello#wow"); EXPECT_EQ(out_result, true); + out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, + std::numeric_limits::max() / 2 + 1, true, "", 0, + true, "", 0, true, "", 0, true, &out_result, &out_len); + EXPECT_STREQ(out, ""); + EXPECT_EQ(out_len, 0); + EXPECT_EQ(out_result, false); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); + ctx.Reset(); + out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, "", 0, true, word2, word2_len, false, word3, word3_len, true, &out_result, &out_len); @@ -2498,6 +2539,14 @@ TEST(TestStringOps, TestToHex) { output = std::string(out_str, out_len); EXPECT_EQ(out_len, 2 * in_len); EXPECT_EQ(output, "090A090A090A090A0A0A092061206C657474405D6572"); + + out_str = + to_hex_binary(ctx_ptr, "A", std::numeric_limits::max() / 2 + 1, &out_len); + EXPECT_STREQ(out_str, ""); + EXPECT_EQ(out_len, 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Would overflow maximum output size")); + ctx.Reset(); } TEST(TestStringOps, TestToHexInt64) {