From 579be1c5422ea2befcfda814c1cbe3cd5eaf391c Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 25 Aug 2015 11:39:43 -0700 Subject: [PATCH] IMPALA-2284: Disallow long (1<<30) strings in group_concat() This is the first step to fix issues with large memory allocations. In this patch, the built-in `group_concat` is no longer allowed to allocate arbitraryly large strings and crash impala, but is limited to the upper bound of possible allocations in Impala. This patch does not perform any functional change, but rather avoids unnecessary crashes. However, it changes the parameter type of FindChunk() in MemPool to be a signed 64bit integer. This change allows the mempool to allocate internally memory of more than one 1GB, but the public interface of Allocate() is not changed, so the general limitation remains. The reason for this change is as follows: 1) In a UDF FunctionContext::Reallocate() would allocate slightly more than 512MB from the FreePool. 2) The free pool tries to double this size to alloocate 1GB from the MemPool. 3) The MemPool doubles the size again and overflows the signed 32bit integer in the FindChunk() method. This will then only allocate 1GB instead of the expected 2GB. What happens is that one of the callers expected a larger allocation than actually happened, which will in turn lead to memory corruption as soon as the memory is accessed. Change-Id: I068835dfa0ac8f7538253d9fa5cfc3fb9d352f6a Reviewed-on: http://gerrit.cloudera.org:8080/858 Tested-by: Internal Jenkins Reviewed-by: Dan Hecht --- be/src/exprs/aggregate-functions.cc | 40 ++++--------- be/src/exprs/expr-test.cc | 4 +- be/src/exprs/string-functions.cc | 8 +++ be/src/runtime/mem-pool.cc | 6 +- be/src/runtime/mem-pool.h | 10 ++-- be/src/runtime/string-value.h | 11 +++- be/src/udf/udf.cc | 58 ++++++++++++++++++- be/src/udf/udf.h | 30 ++++++++-- buildall.sh | 1 - .../queries/QueryTest/large_strings.test | 56 ++++++++++++++++++ tests/query_test/test_queries.py | 4 ++ 11 files changed, 181 insertions(+), 47 deletions(-) create mode 100644 testdata/workloads/functional-query/queries/QueryTest/large_strings.test diff --git a/be/src/exprs/aggregate-functions.cc b/be/src/exprs/aggregate-functions.cc index ce4b07294..18d824494 100644 --- a/be/src/exprs/aggregate-functions.cc +++ b/be/src/exprs/aggregate-functions.cc @@ -127,9 +127,7 @@ StringVal ToStringVal(FunctionContext* context, T val) { stringstream ss; ss << val; const string &str = ss.str(); - StringVal string_val(context, str.size()); - memcpy(string_val.ptr, str.c_str(), str.size()); - return string_val; + return StringVal::CopyFrom(context, reinterpret_cast(str.c_str()), str.size()); } // Delimiter to use if the separator is NULL. @@ -159,9 +157,7 @@ void AggregateFunctions::InitZero(FunctionContext*, DecimalVal* dst) { StringVal AggregateFunctions::StringValGetValue( FunctionContext* ctx, const StringVal& src) { if (src.is_null) return src; - StringVal result(ctx, src.len); - memcpy(result.ptr, src.ptr, src.len); - return result; + return StringVal::CopyFrom(ctx, src.ptr, src.len); } StringVal AggregateFunctions::StringValSerializeOrFinalize( @@ -584,13 +580,7 @@ void AggregateFunctions::StringConcatUpdate(FunctionContext* ctx, *result = StringVal(ctx->Allocate(header_len), header_len); *reinterpret_cast(result->ptr) = sep->len; } - int new_len = result->len + sep->len + src.len; - result->ptr = ctx->Reallocate(result->ptr, new_len); - memcpy(result->ptr + result->len, sep->ptr, sep->len); - result->len += sep->len; - memcpy(result->ptr + result->len, src.ptr, src.len); - result->len += src.len; - DCHECK(result->len == new_len); + result->Append(ctx, sep->ptr, sep->len, src.ptr, src.len); } void AggregateFunctions::StringConcatMerge(FunctionContext* ctx, @@ -600,15 +590,12 @@ void AggregateFunctions::StringConcatMerge(FunctionContext* ctx, if (result->is_null) { // Copy the header from the first intermediate value. *result = StringVal(ctx->Allocate(header_len), header_len); + if (result->is_null) return; *reinterpret_cast(result->ptr) = *reinterpret_cast(src.ptr); } // Append the string portion of the intermediate src to result (omit src's header). - int new_len = result->len + src.len - header_len; - result->ptr = ctx->Reallocate(result->ptr, new_len); - memcpy(result->ptr + result->len, src.ptr + header_len, src.len - header_len); - result->len += src.len - header_len; - DCHECK(result->len == new_len); + result->Append(ctx, src.ptr + header_len, src.len - header_len); } StringVal AggregateFunctions::StringConcatFinalize(FunctionContext* ctx, @@ -619,8 +606,8 @@ StringVal AggregateFunctions::StringConcatFinalize(FunctionContext* ctx, int sep_len = *reinterpret_cast(src.ptr); DCHECK(src.len >= header_len + sep_len); // Remove the header and the first separator. - StringVal result(ctx, src.len - header_len - sep_len); - memcpy(result.ptr, src.ptr + header_len + sep_len, result.len); + StringVal result = StringVal::CopyFrom(ctx, src.ptr + header_len + sep_len, + src.len - header_len - sep_len); ctx->Free(src.ptr); return result; } @@ -848,9 +835,7 @@ struct ReservoirSample { // Gets a copy of the sample value that allocates memory from ctx, if necessary. StringVal GetValue(FunctionContext* ctx) { - StringVal result = StringVal(ctx, len); - memcpy(result.ptr, &val[0], len); - return result; + return StringVal::CopyFrom(ctx, &val[0], len); } }; @@ -905,8 +890,7 @@ template const StringVal AggregateFunctions::ReservoirSampleSerialize(FunctionContext* ctx, const StringVal& src) { if (src.is_null) return src; - StringVal result(ctx, src.len); - memcpy(result.ptr, src.ptr, src.len); + StringVal result = StringVal::CopyFrom(ctx, src.ptr, src.len); ctx->Free(src.ptr); ReservoirSampleState* state = reinterpret_cast*>(result.ptr); @@ -1057,8 +1041,8 @@ StringVal AggregateFunctions::HistogramFinalize(FunctionContext* ctx, if (bucket_idx < (num_buckets - 1)) out << ", "; } const string& out_str = out.str(); - StringVal result_str(ctx, out_str.size()); - memcpy(result_str.ptr, out_str.c_str(), result_str.len); + StringVal result_str = StringVal::CopyFrom(ctx, + reinterpret_cast(out_str.c_str()), out_str.size()); ctx->Free(src_val.ptr); return result_str; } @@ -1378,7 +1362,7 @@ void AggregateFunctions::FirstValUpdate(FunctionContext* ctx, const StringVal& s return; } *dst = StringVal(ctx->Allocate(src.len), src.len); - memcpy(dst->ptr, src.ptr, src.len); + if (!dst->is_null) memcpy(dst->ptr, src.ptr, src.len); } template diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc index 9e1fed270..1a92e6ee4 100644 --- a/be/src/exprs/expr-test.cc +++ b/be/src/exprs/expr-test.cc @@ -246,7 +246,9 @@ class ExprTest : public testing::Test { // We convert the expected result to string. case TYPE_FLOAT: case TYPE_DOUBLE: - expr_value_.string_val = value; + // Construct a StringValue from 'value'. 'value' must be valid for as long as + // this object is valid. + expr_value_.string_val = StringValue(value); return &expr_value_.string_val; case TYPE_TINYINT: expr_value_.tinyint_val = diff --git a/be/src/exprs/string-functions.cc b/be/src/exprs/string-functions.cc index 7f1aefc84..0c3725335 100644 --- a/be/src/exprs/string-functions.cc +++ b/be/src/exprs/string-functions.cc @@ -85,6 +85,7 @@ StringVal StringFunctions::Repeat( if (str.is_null || n.is_null) return StringVal::null(); if (str.len == 0 || n.val <= 0) return StringVal(); StringVal result(context, str.len * n.val); + if (UNLIKELY(result.is_null)) return result; uint8_t* ptr = result.ptr; for (int64_t i = 0; i < n.val; ++i) { memcpy(ptr, str.ptr, str.len); @@ -102,6 +103,7 @@ StringVal StringFunctions::Lpad(FunctionContext* context, const StringVal& str, if (len.val <= str.len || pad.len == 0) return StringVal(str.ptr, len.val); StringVal result(context, len.val); + if (result.is_null) return result; int padded_prefix_len = len.val - str.len; int pad_index = 0; int result_index = 0; @@ -129,6 +131,7 @@ StringVal StringFunctions::Rpad(FunctionContext* context, const StringVal& str, } StringVal result(context, len.val); + if (UNLIKELY(result.is_null)) return result; memcpy(result.ptr, str.ptr, str.len); // Append chars of pad until desired length @@ -157,6 +160,7 @@ IntVal StringFunctions::CharLength(FunctionContext* context, const StringVal& st StringVal StringFunctions::Lower(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; for (int i = 0; i < str.len; ++i) { result.ptr[i] = ::tolower(str.ptr[i]); } @@ -166,6 +170,7 @@ StringVal StringFunctions::Lower(FunctionContext* context, const StringVal& str) StringVal StringFunctions::Upper(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; for (int i = 0; i < str.len; ++i) { result.ptr[i] = ::toupper(str.ptr[i]); } @@ -179,6 +184,7 @@ StringVal StringFunctions::Upper(FunctionContext* context, const StringVal& str) StringVal StringFunctions::InitCap(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; uint8_t* result_ptr = result.ptr; bool word_start = true; for (int i = 0; i < str.len; ++i) { @@ -196,6 +202,7 @@ StringVal StringFunctions::InitCap(FunctionContext* context, const StringVal& st StringVal StringFunctions::Reverse(FunctionContext* context, const StringVal& str) { if (str.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; std::reverse_copy(str.ptr, str.ptr + str.len, result.ptr); return result; } @@ -204,6 +211,7 @@ StringVal StringFunctions::Translate(FunctionContext* context, const StringVal& const StringVal& src, const StringVal& dst) { if (str.is_null || src.is_null || dst.is_null) return StringVal::null(); StringVal result(context, str.len); + if (UNLIKELY(result.is_null)) return result; // TODO: if we know src and dst are constant, we can prebuild a conversion // table to remove the inner loop. diff --git a/be/src/runtime/mem-pool.cc b/be/src/runtime/mem-pool.cc index 97a2a5981..48446daf8 100644 --- a/be/src/runtime/mem-pool.cc +++ b/be/src/runtime/mem-pool.cc @@ -45,7 +45,7 @@ MemPool::MemPool(MemTracker* mem_tracker, int chunk_size) DCHECK(mem_tracker != NULL); } -MemPool::ChunkInfo::ChunkInfo(int size) +MemPool::ChunkInfo::ChunkInfo(int64_t size) : owns_data(true), data(reinterpret_cast(malloc(size))), size(size), @@ -89,7 +89,7 @@ void MemPool::FreeAll() { } } -bool MemPool::FindChunk(int min_size, bool check_limits) { +bool MemPool::FindChunk(int64_t min_size, bool check_limits) { // Try to allocate from a free chunk. The first free chunk, if any, will be immediately // after the current chunk. int first_free_idx = current_chunk_idx_ + 1; @@ -111,7 +111,7 @@ bool MemPool::FindChunk(int min_size, bool check_limits) { if (current_chunk_idx_ == static_cast(chunks_.size())) { // need to allocate new chunk. - int chunk_size = chunk_size_; + int64_t chunk_size = chunk_size_; if (chunk_size == 0) { if (current_chunk_idx_ == 0) { chunk_size = DEFAULT_INITIAL_CHUNK_SIZE; diff --git a/be/src/runtime/mem-pool.h b/be/src/runtime/mem-pool.h index d36003487..77bb3e551 100644 --- a/be/src/runtime/mem-pool.h +++ b/be/src/runtime/mem-pool.h @@ -169,17 +169,17 @@ class MemPool { struct ChunkInfo { bool owns_data; // true if we eventually need to dealloc data uint8_t* data; - int size; // in bytes + int64_t size; // in bytes /// number of bytes allocated via Allocate() up to but excluding this chunk; /// *not* valid for chunks > current_chunk_idx_ (because that would create too /// much maintenance work if we have trailing unoccupied chunks) - int cumulative_allocated_bytes; + int64_t cumulative_allocated_bytes; /// bytes allocated via Allocate() in this chunk - int allocated_bytes; + int64_t allocated_bytes; - explicit ChunkInfo(int size); + explicit ChunkInfo(int64_t size); ChunkInfo() : owns_data(true), @@ -222,7 +222,7 @@ class MemPool { /// if a new chunk needs to be created. /// If check_limits is true, this call can fail (returns false) if adding a /// new chunk exceeds the mem limits. - bool FindChunk(int min_size, bool check_limits); + bool FindChunk(int64_t min_size, bool check_limits); /// Check integrity of the supporting data structures; always returns true but DCHECKs /// all invariants. diff --git a/be/src/runtime/string-value.h b/be/src/runtime/string-value.h index 11727be84..63f8c5426 100644 --- a/be/src/runtime/string-value.h +++ b/be/src/runtime/string-value.h @@ -31,6 +31,10 @@ namespace impala { /// shares its buffer the parent. /// TODO: rename this to be less confusing with impala_udf::StringVal. struct StringValue { + /// The current limitation for a string instance is 1GB character data. + /// See IMPALA-1619 for more details. + static const int MAX_LENGTH = (1 << 30); + /// TODO: change ptr to an offset relative to a contiguous memory block, /// so that we can send row batches between nodes without having to swizzle /// pointers @@ -39,21 +43,24 @@ struct StringValue { StringValue(char* ptr, int len): ptr(ptr), len(len) { DCHECK_GE(len, 0); + DCHECK_LE(len, MAX_LENGTH); } StringValue(): ptr(NULL), len(0) {} /// Construct a StringValue from 's'. 's' must be valid for as long as /// this object is valid. - StringValue(const std::string& s) + explicit StringValue(const std::string& s) : ptr(const_cast(s.c_str())), len(s.size()) { + DCHECK_LE(len, MAX_LENGTH); } /// Construct a StringValue from 's'. 's' must be valid for as long as /// this object is valid. /// s must be a null-terminated string. This constructor is to prevent /// accidental use of the version taking an std::string. - StringValue(const char* s) + explicit StringValue(const char* s) : ptr(const_cast(s)), len(strlen(s)) { + DCHECK_LE(len, MAX_LENGTH); } /// Byte-by-byte comparison. Returns: diff --git a/be/src/udf/udf.cc b/be/src/udf/udf.cc index 780f3a58a..571523a7e 100644 --- a/be/src/udf/udf.cc +++ b/be/src/udf/udf.cc @@ -102,6 +102,7 @@ class RuntimeState { #endif #include "common/names.h" +#include "common/compiler-util.h" using namespace impala; using namespace impala_udf; @@ -424,9 +425,64 @@ void FunctionContextImpl::SetConstantArgs(const vector& constant_args) // Note: this function crashes LLVM's JIT in expr-test if it's xcompiled. Do not move to // expr-ir.cc. This could probably use further investigation. StringVal::StringVal(FunctionContext* context, int len) - : len(len), ptr(context->impl()->AllocateLocal(len)) { + : len(len), ptr(NULL) { + if (UNLIKELY(len > StringVal::MAX_LENGTH)) { + std::cout << "MAX_LENGTH, Trying to allocate " << len; + context->SetError("String length larger than allowed limit of " + "1 GB character data."); + len = 0; + is_null = true; + } else { + ptr = context->impl()->AllocateLocal(len); + if (ptr == NULL && len > 0) { + len = 0; + is_null = true; + context->SetError("Large Memory allocation failed."); + } + } } +StringVal StringVal::CopyFrom(FunctionContext* ctx, const uint8_t* buf, size_t len) { + StringVal result(ctx, len); + if (!result.is_null) { + memcpy(result.ptr, buf, len); + } + return result; +} + +void StringVal::Append(FunctionContext* ctx, const uint8_t* buf, size_t buf_len) { + if (UNLIKELY(len + buf_len > StringVal::MAX_LENGTH)) { + ctx->SetError("Concatenated string length larger than allowed limit of " + "1 GB character data."); + ctx->Free(ptr); + ptr = NULL; + len = 0; + is_null = true; + } else { + ptr = ctx->Reallocate(ptr, len + buf_len); + memcpy(ptr + len, buf, buf_len); + len += buf_len; + } +} +void StringVal::Append(FunctionContext* ctx, const uint8_t* buf, size_t buf_len, + const uint8_t* buf2, size_t buf2_len) { + if (UNLIKELY(len + buf_len + buf2_len > StringVal::MAX_LENGTH)) { + ctx->SetError("Concatenated string length larger than allowed limit of " + "1 GB character data."); + ctx->Free(ptr); + ptr = NULL; + len = 0; + is_null = true; + } else { + ptr = ctx->Reallocate(ptr, len + buf_len + buf2_len); + memcpy(ptr + len, buf, buf_len); + memcpy(ptr + len + buf_len, buf2, buf2_len); + len += buf_len + buf2_len; + } +} + + + // TODO: why doesn't libudasample.so build if this in udf-ir.cc? const FunctionContext::TypeDesc* FunctionContext::GetArgType(int arg_idx) const { if (arg_idx < 0 || arg_idx >= impl_->arg_types_.size()) return NULL; diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index e6448c07e..b67723f2e 100644 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -523,6 +523,9 @@ struct TimestampVal : public AnyVal { /// Note: there is a difference between a NULL string (is_null == true) and an /// empty string (len == 0). struct StringVal : public AnyVal { + + static const int MAX_LENGTH = (1 << 30); + int len; uint8_t* ptr; @@ -532,22 +535,37 @@ struct StringVal : public AnyVal { assert(len >= 0); }; - /// Construct a StringVal from NULL-terminated c-string. Note: this does not make a /// copy of ptr so the underlying string must exist as long as this StringVal does. StringVal(const char* ptr) : len(strlen(ptr)), ptr((uint8_t*)ptr) {} + /// Creates a StringVal, allocating a new buffer with 'len'. This should + /// be used to return StringVal objects in UDF/UDAs that need to allocate new + /// string memory. + /// + /// If the memory allocation fails, e.g. because the intermediate value would be too + /// large, the constructor will construct a NULL string and set an error on the function + /// context. + StringVal(FunctionContext* context, int len); + + /// Will create a new StringVal with the given dimension and copy the data from the + /// parameters. In case of an error will return a NULL string and set an error on the + /// function context. + static StringVal CopyFrom(FunctionContext* ctx, const uint8_t* buf, size_t len); + + /// Append the passed buffer to this StringVal. Reallocate memory to fit the buffer. If + /// the memory allocation becomes too large, will set an error on FunctionContext and + /// return a NULL string. + void Append(FunctionContext* ctx, const uint8_t* buf, size_t len); + void Append(FunctionContext* ctx, const uint8_t* buf, size_t len, const uint8_t* buf2, + size_t buf2_len); + static StringVal null() { StringVal sv; sv.is_null = true; return sv; } - /// Creates a StringVal, allocating a new buffer with 'len'. This should - /// be used to return StringVal objects in UDF/UDAs that need to allocate new - /// string memory. - StringVal(FunctionContext* context, int len); - bool operator==(const StringVal& other) const { if (is_null != other.is_null) return false; if (is_null) return true; diff --git a/buildall.sh b/buildall.sh index 664ce8500..6e1b3be2f 100755 --- a/buildall.sh +++ b/buildall.sh @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - # run buildall.sh -help to see options ROOT=`dirname "$0"` diff --git a/testdata/workloads/functional-query/queries/QueryTest/large_strings.test b/testdata/workloads/functional-query/queries/QueryTest/large_strings.test new file mode 100644 index 000000000..41b7916b6 --- /dev/null +++ b/testdata/workloads/functional-query/queries/QueryTest/large_strings.test @@ -0,0 +1,56 @@ +==== +---- QUERY +-- IMPALA-1619 group_concat() error +select length(group_concat(l_comment, "!")) from (select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem union all +select l_comment from tpch_parquet.lineitem) a; +---- CATCH +Concatenated string length larger than allowed limit of 1 GB character data. +===== \ No newline at end of file diff --git a/tests/query_test/test_queries.py b/tests/query_test/test_queries.py index f6732657a..9a81d2c5e 100644 --- a/tests/query_test/test_queries.py +++ b/tests/query_test/test_queries.py @@ -84,6 +84,10 @@ class TestQueries(ImpalaTestSuite): def test_union(self, vector): self.run_test_case('QueryTest/union', vector) + def test_very_large_strings(self, vector): + """Regression test for IMPALA-1619""" + self.run_test_case('QueryTest/large_strings', vector) + def test_sort(self, vector): if vector.get_value('table_format').file_format == 'hbase': pytest.xfail(reason="IMPALA-283 - select count(*) produces inconsistent results")