diff --git a/be/src/exec/analytic-eval-node.cc b/be/src/exec/analytic-eval-node.cc index ab9f7f288..93960c9e7 100644 --- a/be/src/exec/analytic-eval-node.cc +++ b/be/src/exec/analytic-eval-node.cc @@ -235,6 +235,7 @@ Status AnalyticEvalNode::Open(RuntimeState* state) { AggFnEvaluator::Init(evaluators_, fn_ctxs_, curr_tuple_); // Allocate dummy_result_tuple_ even if AggFnEvaluator::Init() may have failed // as it is needed in Close(). + // TODO: move this to Prepare() dummy_result_tuple_ = Tuple::Create(result_tuple_desc_->byte_size(), mem_pool_.get()); // Check for failures during AggFnEvaluator::Init(). RETURN_IF_ERROR(state->GetQueryStatus()); diff --git a/be/src/exprs/string-functions.cc b/be/src/exprs/string-functions.cc index 542d1d850..f5249036d 100644 --- a/be/src/exprs/string-functions.cc +++ b/be/src/exprs/string-functions.cc @@ -77,6 +77,7 @@ StringVal StringFunctions::Space(FunctionContext* context, const BigIntVal& len) if (len.is_null) return StringVal::null(); if (len.val <= 0) return StringVal(); StringVal result(context, len.val); + if (UNLIKELY(result.is_null)) return StringVal::null(); memset(result.ptr, ' ', len.val); return result; } @@ -86,7 +87,7 @@ StringVal StringFunctions::Repeat( if (str.is_null || n.is_null) return StringVal::null(); if (str.len == 0 || n.val <= 0) return StringVal(); StringVal result(context, str.len * n.val); - if (UNLIKELY(result.is_null)) return result; + if (UNLIKELY(result.is_null)) return StringVal::null(); uint8_t* ptr = result.ptr; for (int64_t i = 0; i < n.val; ++i) { memcpy(ptr, str.ptr, str.len); @@ -104,7 +105,7 @@ StringVal StringFunctions::Lpad(FunctionContext* context, const StringVal& str, if (len.val <= str.len || pad.len == 0) return StringVal(str.ptr, len.val); StringVal result(context, len.val); - if (result.is_null) return result; + if (UNLIKELY(result.is_null)) return StringVal::null(); int padded_prefix_len = len.val - str.len; int pad_index = 0; int result_index = 0; @@ -132,7 +133,7 @@ StringVal StringFunctions::Rpad(FunctionContext* context, const StringVal& str, } StringVal result(context, len.val); - if (UNLIKELY(result.is_null)) return result; + if (UNLIKELY(result.is_null)) return StringVal::null(); memcpy(result.ptr, str.ptr, str.len); // Append chars of pad until desired length @@ -161,7 +162,7 @@ IntVal StringFunctions::CharLength(FunctionContext* context, const StringVal& st StringVal StringFunctions::Lower(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); - if (UNLIKELY(result.is_null)) return result; + if (UNLIKELY(result.is_null)) return StringVal::null(); for (int i = 0; i < str.len; ++i) { result.ptr[i] = ::tolower(str.ptr[i]); } @@ -171,7 +172,7 @@ StringVal StringFunctions::Lower(FunctionContext* context, const StringVal& str) StringVal StringFunctions::Upper(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); - if (UNLIKELY(result.is_null)) return result; + if (UNLIKELY(result.is_null)) return StringVal::null(); for (int i = 0; i < str.len; ++i) { result.ptr[i] = ::toupper(str.ptr[i]); } @@ -185,7 +186,7 @@ StringVal StringFunctions::Upper(FunctionContext* context, const StringVal& str) StringVal StringFunctions::InitCap(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); - if (UNLIKELY(result.is_null)) return result; + if (UNLIKELY(result.is_null)) return StringVal::null(); uint8_t* result_ptr = result.ptr; bool word_start = true; for (int i = 0; i < str.len; ++i) { @@ -203,7 +204,7 @@ StringVal StringFunctions::InitCap(FunctionContext* context, const StringVal& st StringVal StringFunctions::Reverse(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); - if (UNLIKELY(result.is_null)) return result; + if (UNLIKELY(result.is_null)) return StringVal::null(); std::reverse_copy(str.ptr, str.ptr + str.len, result.ptr); return result; } @@ -305,7 +306,7 @@ IntVal StringFunctions::LocatePos(FunctionContext* context, const StringVal& sub StringSearch search(&substr_sv); // Input start_pos.val starts from 1. StringValue adjusted_str(reinterpret_cast(str.ptr) + start_pos.val - 1, - str.len - start_pos.val + 1); + str.len - start_pos.val + 1); int32_t match_pos = search.Search(&adjusted_str); if (match_pos >= 0) { // Hive returns the position in the original string starting from 1. @@ -318,6 +319,7 @@ IntVal StringFunctions::LocatePos(FunctionContext* context, const StringVal& sub // The caller owns the returned regex. Returns NULL if the pattern could not be compiled. re2::RE2* CompileRegex(const StringVal& pattern, string* error_str, const StringVal& match_parameter) { + DCHECK(error_str != NULL); re2::StringPiece pattern_sp(reinterpret_cast(pattern.ptr), pattern.len); re2::RE2::Options options; // Disable error logging in case e.g. every row causes an error @@ -547,6 +549,7 @@ StringVal StringFunctions::Concat(FunctionContext* context, int num_children, StringVal StringFunctions::ConcatWs(FunctionContext* context, const StringVal& sep, int num_children, const StringVal* strs) { DCHECK_GE(num_children, 1); + DCHECK(strs != NULL); if (sep.is_null) return StringVal::null(); // Pass through if there's only one argument @@ -561,9 +564,10 @@ StringVal StringFunctions::ConcatWs(FunctionContext* context, const StringVal& s total_size += sep.len + strs[i].len; } StringVal result(context, total_size); - uint8_t* ptr = result.ptr; + if (UNLIKELY(result.is_null)) return StringVal::null(); // Loop again to append the data. + uint8_t* ptr = result.ptr; memcpy(ptr, strs[0].ptr, strs[0].len); ptr += strs[0].len; for (int32_t i = 1; i < num_children; ++i) { @@ -662,7 +666,7 @@ void StringFunctions::ParseUrlClose( } StringVal StringFunctions::ParseUrlKey(FunctionContext* ctx, const StringVal& url, - const StringVal& part, const StringVal& key) { + const StringVal& part, const StringVal& key) { if (url.is_null || part.is_null || key.is_null) return StringVal::null(); void* state = ctx->GetFunctionState(FunctionContext::FRAGMENT_LOCAL); UrlParser::UrlPart url_part; @@ -732,6 +736,7 @@ StringVal StringFunctions::BTrimString(FunctionContext* ctx, // characters here instead of using the bitset from function context. if (!ctx->IsArgConstant(1)) { unique_chars->reset(); + DCHECK(chars_to_trim.len != 0 || chars_to_trim.is_null); for (int32_t i = 0; i < chars_to_trim.len; ++i) { unique_chars->set(static_cast(chars_to_trim.ptr[i]), true); } @@ -751,8 +756,9 @@ StringVal StringFunctions::BTrimString(FunctionContext* ctx, } // Similar to strstr() except that the strings are not null-terminated -static char* locate_substring(char* haystack, int hay_len, char* needle, int needle_len) { +static char* LocateSubstring(char* haystack, int hay_len, const char* needle, int needle_len) { DCHECK_GT(needle_len, 0); + DCHECK(haystack != NULL && needle != NULL); for (int i = 0; i < hay_len - needle_len + 1; ++i) { char* possible_needle = haystack + i; if (strncmp(possible_needle, needle, needle_len) == 0) return possible_needle; @@ -776,7 +782,7 @@ StringVal StringFunctions::SplitPart(FunctionContext* context, char* delimiter = reinterpret_cast(delim.ptr); for (int cur_pos = 1; ; ++cur_pos) { int remaining_len = str.len - (str_part - str_start); - char* delim_ref = locate_substring(str_part, remaining_len, delimiter, delim.len); + char* delim_ref = LocateSubstring(str_part, remaining_len, delimiter, delim.len); if (delim_ref == NULL) { if (cur_pos == field_pos) { return StringVal(reinterpret_cast(str_part), remaining_len); diff --git a/be/src/exprs/timestamp-functions.cc b/be/src/exprs/timestamp-functions.cc index 1211f3232..98084199b 100644 --- a/be/src/exprs/timestamp-functions.cc +++ b/be/src/exprs/timestamp-functions.cc @@ -149,6 +149,7 @@ StringVal TimestampFunctions::StringValFromTimestamp(FunctionContext* context, int buff_len = dt_ctx->fmt_out_len + 1; StringVal result(context, buff_len); + if (UNLIKELY(result.is_null)) return StringVal::null(); result.len = tv.Format(*dt_ctx, buff_len, reinterpret_cast(result.ptr)); if (result.len <= 0) return StringVal::null(); return result; diff --git a/be/src/exprs/udf-builtins.cc b/be/src/exprs/udf-builtins.cc index 128ed8500..0de2de331 100644 --- a/be/src/exprs/udf-builtins.cc +++ b/be/src/exprs/udf-builtins.cc @@ -120,6 +120,7 @@ struct TruncUnit { // Returns the TruncUnit for the given string TruncUnit::Type StrToTruncUnit(FunctionContext* ctx, const StringVal& unit_str) { StringVal unit = UdfBuiltins::Lower(ctx, unit_str); + if (UNLIKELY(unit.is_null)) return TruncUnit::UNIT_INVALID; if ((unit == "syyyy") || (unit == "yyyy") || (unit == "year") || (unit == "syear") || (unit == "yyy") || (unit == "yy") || (unit == "y")) { return TruncUnit::YEAR; @@ -336,6 +337,7 @@ void UdfBuiltins::TruncClose(FunctionContext* ctx, // Returns the TExtractField for the given unit TExtractField::type StrToExtractField(FunctionContext* ctx, const StringVal& unit_str) { StringVal unit = UdfBuiltins::Lower(ctx, unit_str); + if (UNLIKELY(unit.is_null)) return TExtractField::INVALID_FIELD; if (unit == "year") return TExtractField::YEAR; if (unit == "month") return TExtractField::MONTH; if (unit == "day") return TExtractField::DAY; @@ -480,6 +482,7 @@ bool ValidateMADlibVector(FunctionContext* context, const StringVal& arr) { StringVal UdfBuiltins::ToVector(FunctionContext* context, int n, const DoubleVal* vals) { StringVal s(context, n * sizeof(double)); + if (UNLIKELY(s.is_null)) return StringVal::null(); double* darr = reinterpret_cast(s.ptr); for (int i = 0; i < n; ++i) { if (vals[i].is_null) { @@ -504,6 +507,7 @@ StringVal UdfBuiltins::PrintVector(FunctionContext* context, const StringVal& ar ss << ">"; const string& str = ss.str(); StringVal result(context, str.size()); + if (UNLIKELY(result.is_null)) return StringVal::null(); memcpy(result.ptr, str.c_str(), str.size()); return result; } @@ -553,6 +557,7 @@ StringVal UdfBuiltins::EncodeVector(FunctionContext* context, const StringVal& a double* darr = reinterpret_cast(arr.ptr); int len = arr.len / sizeof(double); StringVal result(context, arr.len); + if (UNLIKELY(result.is_null)) return StringVal::null(); memcpy(result.ptr, darr, arr.len); InplaceDoubleEncode(reinterpret_cast(result.ptr), len); return result; @@ -561,6 +566,7 @@ StringVal UdfBuiltins::EncodeVector(FunctionContext* context, const StringVal& a StringVal UdfBuiltins::DecodeVector(FunctionContext* context, const StringVal& arr) { if (arr.is_null) return StringVal::null(); StringVal result(context, arr.len); + if (UNLIKELY(result.is_null)) return StringVal::null(); memcpy(result.ptr, arr.ptr, arr.len); InplaceDoubleDecode(reinterpret_cast(result.ptr), arr.len); return result; diff --git a/be/src/exprs/utility-functions.cc b/be/src/exprs/utility-functions.cc index b990bdaca..13bacf114 100644 --- a/be/src/exprs/utility-functions.cc +++ b/be/src/exprs/utility-functions.cc @@ -32,13 +32,13 @@ using namespace strings; namespace impala { BigIntVal UtilityFunctions::FnvHashString(FunctionContext* ctx, - const StringVal& input_val) { + const StringVal& input_val) { if (input_val.is_null) return BigIntVal::null(); return BigIntVal(HashUtil::FnvHash64(input_val.ptr, input_val.len, HashUtil::FNV_SEED)); } BigIntVal UtilityFunctions::FnvHashTimestamp(FunctionContext* ctx, - const TimestampVal& input_val) { + const TimestampVal& input_val) { if (input_val.is_null) return BigIntVal::null(); TimestampValue tv = TimestampValue::FromTimestampVal(input_val); return BigIntVal(HashUtil::FnvHash64(&tv, 12, HashUtil::FNV_SEED)); @@ -54,7 +54,7 @@ BigIntVal UtilityFunctions::FnvHash(FunctionContext* ctx, const T& input_val) { // Note that this only hashes the unscaled value and not the scale or precision, so this // function is only valid when used over a single decimal type. BigIntVal UtilityFunctions::FnvHashDecimal(FunctionContext* ctx, - const DecimalVal& input_val) { + const DecimalVal& input_val) { if (input_val.is_null) return BigIntVal::null(); const FunctionContext::TypeDesc& input_type = *ctx->GetArgType(0); int byte_size = ColumnType::GetDecimalByteSize(input_type.precision); diff --git a/be/src/runtime/string-search.h b/be/src/runtime/string-search.h index 4a741e9b1..d15d65e2e 100644 --- a/be/src/runtime/string-search.h +++ b/be/src/runtime/string-search.h @@ -109,7 +109,7 @@ class StringSearch { /// Returns -1 if the pattern is not found int Search(const StringValue* str) const { // Special cases - if (!str || !pattern_ || pattern_->len == 0) { + if (str == NULL || pattern_ == NULL || pattern_->len == 0) { return -1; } diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index 47b36e2b9..83551c19b 100644 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -548,7 +548,9 @@ struct TimestampVal : public AnyVal { /// empty string (len == 0). struct StringVal : public AnyVal { - static const int MAX_LENGTH = (1 << 30); + // It's important to keep this as unsigned to avoid comparing with negative number + // in case of overflow. + static const unsigned MAX_LENGTH = (1 << 30); int len; uint8_t* ptr; diff --git a/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-init.test b/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-init.test index cdf25aa4c..338f8bd64 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-init.test +++ b/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-init.test @@ -1,5 +1,6 @@ ==== ---- QUERY +# TODO: IMPALA-3350: Add 'group by' to these tests to exercise different code paths. select ndv(string_col) from functional.alltypes ---- CATCH FunctionContext::Allocate() failed to allocate 1024 bytes. diff --git a/testdata/workloads/functional-query/queries/QueryTest/large_strings.test b/testdata/workloads/functional-query/queries/QueryTest/large_strings.test index b25e5c80e..5c9086933 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/large_strings.test +++ b/testdata/workloads/functional-query/queries/QueryTest/large_strings.test @@ -112,3 +112,82 @@ select l_comment from tpch_parquet.lineitem) a ---- CATCH Memory limit exceeded ==== +---- QUERY +#IMPALA-3350: Results of string functions can exceed 1GB. +select length(concat_ws(',', s, s, s, s)) from ( + select group_concat(l_comment, "!") s from ( + select l_comment from tpch.lineitem union all + select l_comment from tpch.lineitem) t1 + ) t2 +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select length(repeat(s, 10)) from ( + select group_concat(l_comment, "!") s from ( + select l_comment from tpch.lineitem union all + select l_comment from tpch.lineitem) t1 + ) t2 +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select length(lpad(s, 1073741830, '!')) from ( + select group_concat(l_comment, "!") s from ( + select l_comment from tpch.lineitem union all + select l_comment from tpch.lineitem) t1 + ) t2 +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select length(rpad(s, 1073741830, '~')) from ( + select group_concat(l_comment, "!") s from ( + select l_comment from tpch.lineitem union all + select l_comment from tpch.lineitem) t1 + ) t2 +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select space(1073741830); +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select length(regexp_replace(s, '.', '++++++++')) from ( + select group_concat(l_comment, "!") s from ( + select l_comment from tpch.lineitem union all + select l_comment from tpch.lineitem) t1 + ) t2 +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select trunc(timestamp_col, space(1073741830)) from functional.alltypes +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select extract(timestamp_col, space(1073741830)) from functional.alltypes +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select length(madlib_encode_vector(concat_ws(',', s, s, s, s))) from ( + select group_concat(l_comment, "!") s from ( + select l_comment from tpch.lineitem union all + select l_comment from tpch.lineitem) t1 + ) t2 +---- CATCH +String length larger than allowed limit of 1 GB character data +===== +---- QUERY +select length(madlib_decode_vector(concat_ws(',', s, s, s, s))) from ( + select group_concat(l_comment, "!") s from ( + select l_comment from tpch.lineitem union all + select l_comment from tpch.lineitem) t1 + ) t2 +---- CATCH +String length larger than allowed limit of 1 GB character data +=====