diff --git a/be/src/exprs/aggregate-functions.cc b/be/src/exprs/aggregate-functions.cc index ce4b07294..18d824494 100644 --- a/be/src/exprs/aggregate-functions.cc +++ b/be/src/exprs/aggregate-functions.cc @@ -127,9 +127,7 @@ StringVal ToStringVal(FunctionContext* context, T val) { stringstream ss; ss << val; const string &str = ss.str(); - StringVal string_val(context, str.size()); - memcpy(string_val.ptr, str.c_str(), str.size()); - return string_val; + return StringVal::CopyFrom(context, reinterpret_cast(str.c_str()), str.size()); } // Delimiter to use if the separator is NULL. @@ -159,9 +157,7 @@ void AggregateFunctions::InitZero(FunctionContext*, DecimalVal* dst) { StringVal AggregateFunctions::StringValGetValue( FunctionContext* ctx, const StringVal& src) { if (src.is_null) return src; - StringVal result(ctx, src.len); - memcpy(result.ptr, src.ptr, src.len); - return result; + return StringVal::CopyFrom(ctx, src.ptr, src.len); } StringVal AggregateFunctions::StringValSerializeOrFinalize( @@ -584,13 +580,7 @@ void AggregateFunctions::StringConcatUpdate(FunctionContext* ctx, *result = StringVal(ctx->Allocate(header_len), header_len); *reinterpret_cast(result->ptr) = sep->len; } - int new_len = result->len + sep->len + src.len; - result->ptr = ctx->Reallocate(result->ptr, new_len); - memcpy(result->ptr + result->len, sep->ptr, sep->len); - result->len += sep->len; - memcpy(result->ptr + result->len, src.ptr, src.len); - result->len += src.len; - DCHECK(result->len == new_len); + result->Append(ctx, sep->ptr, sep->len, src.ptr, src.len); } void AggregateFunctions::StringConcatMerge(FunctionContext* ctx, @@ -600,15 +590,12 @@ void AggregateFunctions::StringConcatMerge(FunctionContext* ctx, if (result->is_null) { // Copy the header from the first intermediate value. *result = StringVal(ctx->Allocate(header_len), header_len); + if (result->is_null) return; *reinterpret_cast(result->ptr) = *reinterpret_cast(src.ptr); } // Append the string portion of the intermediate src to result (omit src's header). - int new_len = result->len + src.len - header_len; - result->ptr = ctx->Reallocate(result->ptr, new_len); - memcpy(result->ptr + result->len, src.ptr + header_len, src.len - header_len); - result->len += src.len - header_len; - DCHECK(result->len == new_len); + result->Append(ctx, src.ptr + header_len, src.len - header_len); } StringVal AggregateFunctions::StringConcatFinalize(FunctionContext* ctx, @@ -619,8 +606,8 @@ StringVal AggregateFunctions::StringConcatFinalize(FunctionContext* ctx, int sep_len = *reinterpret_cast(src.ptr); DCHECK(src.len >= header_len + sep_len); // Remove the header and the first separator. - StringVal result(ctx, src.len - header_len - sep_len); - memcpy(result.ptr, src.ptr + header_len + sep_len, result.len); + StringVal result = StringVal::CopyFrom(ctx, src.ptr + header_len + sep_len, + src.len - header_len - sep_len); ctx->Free(src.ptr); return result; } @@ -848,9 +835,7 @@ struct ReservoirSample { // Gets a copy of the sample value that allocates memory from ctx, if necessary. StringVal GetValue(FunctionContext* ctx) { - StringVal result = StringVal(ctx, len); - memcpy(result.ptr, &val[0], len); - return result; + return StringVal::CopyFrom(ctx, &val[0], len); } }; @@ -905,8 +890,7 @@ template const StringVal AggregateFunctions::ReservoirSampleSerialize(FunctionContext* ctx, const StringVal& src) { if (src.is_null) return src; - StringVal result(ctx, src.len); - memcpy(result.ptr, src.ptr, src.len); + StringVal result = StringVal::CopyFrom(ctx, src.ptr, src.len); ctx->Free(src.ptr); ReservoirSampleState* state = reinterpret_cast*>(result.ptr); @@ -1057,8 +1041,8 @@ StringVal AggregateFunctions::HistogramFinalize(FunctionContext* ctx, if (bucket_idx < (num_buckets - 1)) out << ", "; } const string& out_str = out.str(); - StringVal result_str(ctx, out_str.size()); - memcpy(result_str.ptr, out_str.c_str(), result_str.len); + StringVal result_str = StringVal::CopyFrom(ctx, + reinterpret_cast(out_str.c_str()), out_str.size()); ctx->Free(src_val.ptr); return result_str; } @@ -1378,7 +1362,7 @@ void AggregateFunctions::FirstValUpdate(FunctionContext* ctx, const StringVal& s return; } *dst = StringVal(ctx->Allocate(src.len), src.len); - memcpy(dst->ptr, src.ptr, src.len); + if (!dst->is_null) memcpy(dst->ptr, src.ptr, src.len); } template diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc index 9e1fed270..1a92e6ee4 100644 --- a/be/src/exprs/expr-test.cc +++ b/be/src/exprs/expr-test.cc @@ -246,7 +246,9 @@ class ExprTest : public testing::Test { // We convert the expected result to string. case TYPE_FLOAT: case TYPE_DOUBLE: - expr_value_.string_val = value; + // Construct a StringValue from 'value'. 'value' must be valid for as long as + // this object is valid. + expr_value_.string_val = StringValue(value); return &expr_value_.string_val; case TYPE_TINYINT: expr_value_.tinyint_val = diff --git a/be/src/exprs/string-functions.cc b/be/src/exprs/string-functions.cc index 7f1aefc84..0c3725335 100644 --- a/be/src/exprs/string-functions.cc +++ b/be/src/exprs/string-functions.cc @@ -85,6 +85,7 @@ StringVal StringFunctions::Repeat( if (str.is_null || n.is_null) return StringVal::null(); if (str.len == 0 || n.val <= 0) return StringVal(); StringVal result(context, str.len * n.val); + if (UNLIKELY(result.is_null)) return result; uint8_t* ptr = result.ptr; for (int64_t i = 0; i < n.val; ++i) { memcpy(ptr, str.ptr, str.len); @@ -102,6 +103,7 @@ StringVal StringFunctions::Lpad(FunctionContext* context, const StringVal& str, if (len.val <= str.len || pad.len == 0) return StringVal(str.ptr, len.val); StringVal result(context, len.val); + if (result.is_null) return result; int padded_prefix_len = len.val - str.len; int pad_index = 0; int result_index = 0; @@ -129,6 +131,7 @@ StringVal StringFunctions::Rpad(FunctionContext* context, const StringVal& str, } StringVal result(context, len.val); + if (UNLIKELY(result.is_null)) return result; memcpy(result.ptr, str.ptr, str.len); // Append chars of pad until desired length @@ -157,6 +160,7 @@ IntVal StringFunctions::CharLength(FunctionContext* context, const StringVal& st StringVal StringFunctions::Lower(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; for (int i = 0; i < str.len; ++i) { result.ptr[i] = ::tolower(str.ptr[i]); } @@ -166,6 +170,7 @@ StringVal StringFunctions::Lower(FunctionContext* context, const StringVal& str) StringVal StringFunctions::Upper(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; for (int i = 0; i < str.len; ++i) { result.ptr[i] = ::toupper(str.ptr[i]); } @@ -179,6 +184,7 @@ StringVal StringFunctions::Upper(FunctionContext* context, const StringVal& str) StringVal StringFunctions::InitCap(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; uint8_t* result_ptr = result.ptr; bool word_start = true; for (int i = 0; i < str.len; ++i) { @@ -196,6 +202,7 @@ StringVal StringFunctions::InitCap(FunctionContext* context, const StringVal& st StringVal StringFunctions::Reverse(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; std::reverse_copy(str.ptr, str.ptr + str.len, result.ptr); return result; } @@ -204,6 +211,7 @@ StringVal StringFunctions::Translate(FunctionContext* context, const StringVal& const StringVal& src, const StringVal& dst) { if (str.is_null || src.is_null || dst.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; // TODO: if we know src and dst are constant, we can prebuild a conversion // table to remove the inner loop. diff --git a/be/src/runtime/mem-pool.cc b/be/src/runtime/mem-pool.cc index 97a2a5981..48446daf8 100644 --- a/be/src/runtime/mem-pool.cc +++ b/be/src/runtime/mem-pool.cc @@ -45,7 +45,7 @@ MemPool::MemPool(MemTracker* mem_tracker, int chunk_size) DCHECK(mem_tracker != NULL); } -MemPool::ChunkInfo::ChunkInfo(int size) +MemPool::ChunkInfo::ChunkInfo(int64_t size) : owns_data(true), data(reinterpret_cast(malloc(size))), size(size), @@ -89,7 +89,7 @@ void MemPool::FreeAll() { } } -bool MemPool::FindChunk(int min_size, bool check_limits) { +bool MemPool::FindChunk(int64_t min_size, bool check_limits) { // Try to allocate from a free chunk. The first free chunk, if any, will be immediately // after the current chunk. int first_free_idx = current_chunk_idx_ + 1; @@ -111,7 +111,7 @@ bool MemPool::FindChunk(int min_size, bool check_limits) { if (current_chunk_idx_ == static_cast(chunks_.size())) { // need to allocate new chunk. - int chunk_size = chunk_size_; + int64_t chunk_size = chunk_size_; if (chunk_size == 0) { if (current_chunk_idx_ == 0) { chunk_size = DEFAULT_INITIAL_CHUNK_SIZE; diff --git a/be/src/runtime/mem-pool.h b/be/src/runtime/mem-pool.h index d36003487..77bb3e551 100644 --- a/be/src/runtime/mem-pool.h +++ b/be/src/runtime/mem-pool.h @@ -169,17 +169,17 @@ class MemPool { struct ChunkInfo { bool owns_data; // true if we eventually need to dealloc data uint8_t* data; - int size; // in bytes + int64_t size; // in bytes /// number of bytes allocated via Allocate() up to but excluding this chunk; /// *not* valid for chunks > current_chunk_idx_ (because that would create too /// much maintenance work if we have trailing unoccupied chunks) - int cumulative_allocated_bytes; + int64_t cumulative_allocated_bytes; /// bytes allocated via Allocate() in this chunk - int allocated_bytes; + int64_t allocated_bytes; - explicit ChunkInfo(int size); + explicit ChunkInfo(int64_t size); ChunkInfo() : owns_data(true), @@ -222,7 +222,7 @@ class MemPool { /// if a new chunk needs to be created. /// If check_limits is true, this call can fail (returns false) if adding a /// new chunk exceeds the mem limits. - bool FindChunk(int min_size, bool check_limits); + bool FindChunk(int64_t min_size, bool check_limits); /// Check integrity of the supporting data structures; always returns true but DCHECKs /// all invariants. diff --git a/be/src/runtime/string-value.h b/be/src/runtime/string-value.h index 11727be84..63f8c5426 100644 --- a/be/src/runtime/string-value.h +++ b/be/src/runtime/string-value.h @@ -31,6 +31,10 @@ namespace impala { /// shares its buffer the parent. /// TODO: rename this to be less confusing with impala_udf::StringVal. struct StringValue { + /// The current limitation for a string instance is 1GB character data. + /// See IMPALA-1619 for more details. + static const int MAX_LENGTH = (1 << 30); + /// TODO: change ptr to an offset relative to a contiguous memory block, /// so that we can send row batches between nodes without having to swizzle /// pointers @@ -39,21 +43,24 @@ struct StringValue { StringValue(char* ptr, int len): ptr(ptr), len(len) { DCHECK_GE(len, 0); + DCHECK_LE(len, MAX_LENGTH); } StringValue(): ptr(NULL), len(0) {} /// Construct a StringValue from 's'. 's' must be valid for as long as /// this object is valid. - StringValue(const std::string& s) + explicit StringValue(const std::string& s) : ptr(const_cast(s.c_str())), len(s.size()) { + DCHECK_LE(len, MAX_LENGTH); } /// Construct a StringValue from 's'. 's' must be valid for as long as /// this object is valid. /// s must be a null-terminated string. This constructor is to prevent /// accidental use of the version taking an std::string. - StringValue(const char* s) + explicit StringValue(const char* s) : ptr(const_cast(s)), len(strlen(s)) { + DCHECK_LE(len, MAX_LENGTH); } /// Byte-by-byte comparison. Returns: diff --git a/be/src/udf/udf.cc b/be/src/udf/udf.cc index 780f3a58a..571523a7e 100644 --- a/be/src/udf/udf.cc +++ b/be/src/udf/udf.cc @@ -102,6 +102,7 @@ class RuntimeState { #endif #include "common/names.h" +#include "common/compiler-util.h" using namespace impala; using namespace impala_udf; @@ -424,9 +425,64 @@ void FunctionContextImpl::SetConstantArgs(const vector& constant_args) // Note: this function crashes LLVM's JIT in expr-test if it's xcompiled. Do not move to // expr-ir.cc. This could probably use further investigation. StringVal::StringVal(FunctionContext* context, int len) - : len(len), ptr(context->impl()->AllocateLocal(len)) { + : len(len), ptr(NULL) { + if (UNLIKELY(len > StringVal::MAX_LENGTH)) { + std::cout << "MAX_LENGTH, Trying to allocate " << len; + context->SetError("String length larger than allowed limit of " + "1 GB character data."); + len = 0; + is_null = true; + } else { + ptr = context->impl()->AllocateLocal(len); + if (ptr == NULL && len > 0) { + len = 0; + is_null = true; + context->SetError("Large Memory allocation failed."); + } + } } +StringVal StringVal::CopyFrom(FunctionContext* ctx, const uint8_t* buf, size_t len) { + StringVal result(ctx, len); + if (!result.is_null) { + memcpy(result.ptr, buf, len); + } + return result; +} + +void StringVal::Append(FunctionContext* ctx, const uint8_t* buf, size_t buf_len) { + if (UNLIKELY(len + buf_len > StringVal::MAX_LENGTH)) { + ctx->SetError("Concatenated string length larger than allowed limit of " + "1 GB character data."); + ctx->Free(ptr); + ptr = NULL; + len = 0; + is_null = true; + } else { + ptr = ctx->Reallocate(ptr, len + buf_len); + memcpy(ptr + len, buf, buf_len); + len += buf_len; + } +} +void StringVal::Append(FunctionContext* ctx, const uint8_t* buf, size_t buf_len, + const uint8_t* buf2, size_t buf2_len) { + if (UNLIKELY(len + buf_len + buf2_len > StringVal::MAX_LENGTH)) { + ctx->SetError("Concatenated string length larger than allowed limit of " + "1 GB character data."); + ctx->Free(ptr); + ptr = NULL; + len = 0; + is_null = true; + } else { + ptr = ctx->Reallocate(ptr, len + buf_len + buf2_len); + memcpy(ptr + len, buf, buf_len); + memcpy(ptr + len + buf_len, buf2, buf2_len); + len += buf_len + buf2_len; + } +} + + + // TODO: why doesn't libudasample.so build if this in udf-ir.cc? const FunctionContext::TypeDesc* FunctionContext::GetArgType(int arg_idx) const { if (arg_idx < 0 || arg_idx >= impl_->arg_types_.size()) return NULL; diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index e6448c07e..b67723f2e 100644 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -523,6 +523,9 @@ struct TimestampVal : public AnyVal { /// Note: there is a difference between a NULL string (is_null == true) and an /// empty string (len == 0). struct StringVal : public AnyVal { + + static const int MAX_LENGTH = (1 << 30); + int len; uint8_t* ptr; @@ -532,22 +535,37 @@ struct StringVal : public AnyVal { assert(len >= 0); }; - /// Construct a StringVal from NULL-terminated c-string. Note: this does not make a /// copy of ptr so the underlying string must exist as long as this StringVal does. StringVal(const char* ptr) : len(strlen(ptr)), ptr((uint8_t*)ptr) {} + /// Creates a StringVal, allocating a new buffer with 'len'. This should + /// be used to return StringVal objects in UDF/UDAs that need to allocate new + /// string memory. + /// + /// If the memory allocation fails, e.g. because the intermediate value would be too + /// large, the constructor will construct a NULL string and set an error on the function + /// context. + StringVal(FunctionContext* context, int len); + + /// Will create a new StringVal with the given dimension and copy the data from the + /// parameters. In case of an error will return a NULL string and set an error on the + /// function context. + static StringVal CopyFrom(FunctionContext* ctx, const uint8_t* buf, size_t len); + + /// Append the passed buffer to this StringVal. Reallocate memory to fit the buffer. If + /// the memory allocation becomes too large, will set an error on FunctionContext and + /// return a NULL string. + void Append(FunctionContext* ctx, const uint8_t* buf, size_t len); + void Append(FunctionContext* ctx, const uint8_t* buf, size_t len, const uint8_t* buf2, + size_t buf2_len); + static StringVal null() { StringVal sv; sv.is_null = true; return sv; } - /// Creates a StringVal, allocating a new buffer with 'len'. This should - /// be used to return StringVal objects in UDF/UDAs that need to allocate new - /// string memory. - StringVal(FunctionContext* context, int len); - bool operator==(const StringVal& other) const { if (is_null != other.is_null) return false; if (is_null) return true; diff --git a/buildall.sh b/buildall.sh index 664ce8500..6e1b3be2f 100755 --- a/buildall.sh +++ b/buildall.sh @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - # run buildall.sh -help to see options ROOT=`dirname "$0"` diff --git a/testdata/workloads/functional-query/queries/QueryTest/large_strings.test b/testdata/workloads/functional-query/queries/QueryTest/large_strings.test new file mode 100644 index 000000000..41b7916b6 --- /dev/null +++ b/testdata/workloads/functional-query/queries/QueryTest/large_strings.test @@ -0,0 +1,56 @@ +==== +---- QUERY +-- IMPALA-1619 group_concat() error +select length(group_concat(l_comment, "!")) from (select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem) a; +---- CATCH +Concatenated string length larger than allowed limit of 1 GB character data. +===== \ No newline at end of file diff --git a/tests/query_test/test_queries.py b/tests/query_test/test_queries.py index f6732657a..9a81d2c5e 100644 --- a/tests/query_test/test_queries.py +++ b/tests/query_test/test_queries.py @@ -84,6 +84,10 @@ class TestQueries(ImpalaTestSuite): def test_union(self, vector): self.run_test_case('QueryTest/union', vector) + def test_very_large_strings(self, vector): + """Regression test for IMPALA-1619""" + self.run_test_case('QueryTest/large_strings', vector) + def test_sort(self, vector): if vector.get_value('table_format').file_format == 'hbase': pytest.xfail(reason="IMPALA-283 - select count(*) produces inconsistent results")