diff --git a/be/src/exprs/aggregate-functions-ir.cc b/be/src/exprs/aggregate-functions-ir.cc
index 72a78b9b9..611945ad9 100644
--- a/be/src/exprs/aggregate-functions-ir.cc
+++ b/be/src/exprs/aggregate-functions-ir.cc
@@ -1291,10 +1291,8 @@ StringVal AggregateFunctions::ReservoirSampleFinalize(FunctionContext* ctx,
     if (i < (src_state->num_samples() - 1)) out << ", ";
   }
   const string& out_str = out.str();
-  StringVal result_str(ctx, out_str.size());
-  if (LIKELY(!result_str.is_null)) {
-    memcpy(result_str.ptr, out_str.c_str(), result_str.len);
-  }
+  StringVal result_str = StringVal::CopyFrom(ctx,
+      reinterpret_cast<const uint8_t*>(out_str.c_str()), out_str.size());
   src_state->Delete(ctx);
   return result_str;
 }
diff --git a/be/src/exprs/hive-udf-call.cc b/be/src/exprs/hive-udf-call.cc
index cd27dab0b..169b60318 100644
--- a/be/src/exprs/hive-udf-call.cc
+++ b/be/src/exprs/hive-udf-call.cc
@@ -323,15 +323,11 @@ DoubleVal HiveUdfCall::GetDoubleVal(ExprContext* ctx, const TupleRow* row) {
 StringVal HiveUdfCall::GetStringVal(ExprContext* ctx, const TupleRow* row) {
   DCHECK_EQ(type_.type, TYPE_STRING);
   StringVal result = *reinterpret_cast<StringVal*>(Evaluate(ctx, row));
-
   // Copy the string into a local allocation with the usual lifetime for expr results.
   // Needed because the UDF output buffer is owned by the Java UDF executor and may be
   // freed or reused by the next call into the Java UDF executor.
   FunctionContext* fn_ctx = ctx->fn_context(fn_context_index_);
-  uint8_t* local_alloc = fn_ctx->impl()->AllocateLocal(result.len);
-  memcpy(local_alloc, result.ptr, result.len);
-  result.ptr = local_alloc;
-  return result;
+  return StringVal::CopyFrom(fn_ctx, result.ptr, result.len);
 }
 
 TimestampVal HiveUdfCall::GetTimestampVal(ExprContext* ctx, const TupleRow* row) {
diff --git a/be/src/exprs/udf-builtins-ir.cc b/be/src/exprs/udf-builtins-ir.cc
index 376a0d0ce..452c5b4d3 100644
--- a/be/src/exprs/udf-builtins-ir.cc
+++ b/be/src/exprs/udf-builtins-ir.cc
@@ -509,10 +509,8 @@ StringVal UdfBuiltins::PrintVector(FunctionContext* context, const StringVal& ar
   }
   ss << ">";
   const string& str = ss.str();
-  StringVal result(context, str.size());
-  if (UNLIKELY(result.is_null)) return StringVal::null();
-  memcpy(result.ptr, str.c_str(), str.size());
-  return result;
+  return StringVal::CopyFrom(context, reinterpret_cast<const uint8_t*>(str.c_str()),
+      str.size());
 }
 
 DoubleVal UdfBuiltins::VectorGet(FunctionContext* context, const BigIntVal& index,
diff --git a/be/src/udf/uda-test.cc b/be/src/udf/uda-test.cc
index 3ca145f31..d589c75ec 100644
--- a/be/src/udf/uda-test.cc
+++ b/be/src/udf/uda-test.cc
@@ -139,8 +139,7 @@ void MinMerge(FunctionContext* context, const BufferVal& src, BufferVal* dst) {
 StringVal MinFinalize(FunctionContext* context, const BufferVal& val) {
   const MinState* state = reinterpret_cast<const MinState*>(val);
   if (state->value == NULL) return StringVal::null();
-  StringVal result = StringVal(context, state->len);
-  memcpy(result.ptr, state->value, state->len);
+  StringVal result = StringVal::CopyFrom(context, state->value, state->len);
   context->Free(state->value);
   return result;
 }
diff --git a/be/src/udf/udf-test.cc b/be/src/udf/udf-test.cc
index d35ca48dd..796bcb328 100644
--- a/be/src/udf/udf-test.cc
+++ b/be/src/udf/udf-test.cc
@@ -47,6 +47,7 @@ StringVal UpperUdf(FunctionContext* context, const StringVal& input) {
   if (input.is_null) return StringVal::null();
   // Create a new StringVal object that's the same length as the input
   StringVal result = StringVal(context, input.len);
+  if (result.is_null) return StringVal::null();
   for (int i = 0; i < input.len; ++i) {
     result.ptr[i] = toupper(input.ptr[i]);
   }
diff --git a/be/src/udf_samples/uda-sample.cc b/be/src/udf_samples/uda-sample.cc
index 23ce807a0..e8ff4fa7c 100644
--- a/be/src/udf_samples/uda-sample.cc
+++ b/be/src/udf_samples/uda-sample.cc
@@ -87,14 +87,15 @@ void StringConcatUpdate(FunctionContext* context, const StringVal& arg1,
     const StringVal& arg2, StringVal* val) {
   if (val->is_null) {
     val->is_null = false;
-    *val = StringVal(context, arg1.len);
-    memcpy(val->ptr, arg1.ptr, arg1.len);
+    *val = StringVal::CopyFrom(context, arg1.ptr, arg1.len);
   } else {
     int new_len = val->len + arg1.len + arg2.len;
     StringVal new_val(context, new_len);
-    memcpy(new_val.ptr, val->ptr, val->len);
-    memcpy(new_val.ptr + val->len, arg2.ptr, arg2.len);
-    memcpy(new_val.ptr + val->len + arg2.len, arg1.ptr, arg1.len);
+    if (!new_val.is_null) {
+      memcpy(new_val.ptr, val->ptr, val->len);
+      memcpy(new_val.ptr + val->len, arg2.ptr, arg2.len);
+      memcpy(new_val.ptr + val->len + arg2.len, arg1.ptr, arg1.len);
+    }
     *val = new_val;
   }
 }
diff --git a/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-update.test b/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-update.test
index c107ee814..d2b6f6aa8 100644
--- a/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-update.test
+++ b/testdata/workloads/functional-query/queries/QueryTest/alloc-fail-update.test
@@ -50,3 +50,12 @@ select ndv(l_partkey), distinctpc(l_suppkey) from tpch.lineitem
 ---- CATCH
 failed to allocate
 ====
+---- QUERY
+# IMPALA-5252: Verify HiveUdfCall allocations are checked.
+create function replace_string(string) returns string
+location '$FILESYSTEM_PREFIX/test-warehouse/impala-hive-udfs.jar'
+symbol='org.apache.impala.ReplaceStringUdf';
+select replace_string(l_comment) from tpch.lineitem limit 10;
+---- CATCH
+failed to allocate
+====
diff --git a/tests/custom_cluster/test_alloc_fail.py b/tests/custom_cluster/test_alloc_fail.py
index 28cdc191a..f1ae3dc4f 100644
--- a/tests/custom_cluster/test_alloc_fail.py
+++ b/tests/custom_cluster/test_alloc_fail.py
@@ -33,14 +33,14 @@ class TestAllocFail(CustomClusterTestSuite):
   def test_alloc_fail_init(self, vector):
     self.run_test_case('QueryTest/alloc-fail-init', vector)
 
+  # TODO: Rewrite or remove the non-deterministic test.
+  @pytest.mark.xfail(run=True, reason="IMPALA-2925: the execution is not deterministic "
+      "so some tests sometimes don't fail as expected")
   @pytest.mark.execute_serially
   @CustomClusterTestSuite.with_args("--stress_free_pool_alloc=3")
-  def test_alloc_fail_update(self, vector):
-    # TODO: Rewrite or remove the non-deterministic test.
-    pytest.xfail("IMPALA-2925: the execution is not deterministic so some "
-        "tests sometimes don't fail as expected")
+  def test_alloc_fail_update(self, vector, unique_database):
     # Note that this test relies on pre-aggregation to exercise the Serialize() path so
     # query option 'num_nodes' must not be 1. CustomClusterTestSuite.add_test_dimensions()
     # already filters out vectors with 'num_nodes' != 0 so just assert it here.
     assert vector.get_value('exec_option')['num_nodes'] == 0
-    self.run_test_case('QueryTest/alloc-fail-update', vector)
+    self.run_test_case('QueryTest/alloc-fail-update', vector, use_db=unique_database)
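
Reviewer note, not part of the patch: every hunk above converges on the same pattern. Instead of writing through a pointer from an unchecked StringVal(ctx, len), the code either lets StringVal::CopyFrom() allocate through the FunctionContext and checks the possibly-null result, or tests is_null before the memcpy. A minimal hypothetical UDF sketching that pattern is below; the function name EchoUpperUdf and the <impala_udf/udf.h> include path are illustrative assumptions, not code from this change.

// Hypothetical sketch only; not part of this patch.
#include <cctype>
#include <cstdint>
#include <string>

#include <impala_udf/udf.h>

using namespace impala_udf;

StringVal EchoUpperUdf(FunctionContext* ctx, const StringVal& input) {
  if (input.is_null) return StringVal::null();
  // Build the result in a temporary buffer owned by this function.
  std::string upper(reinterpret_cast<const char*>(input.ptr), input.len);
  for (char& c : upper) c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
  // CopyFrom() allocates via the FunctionContext; on failure it reports the error on
  // the context and returns a null StringVal, which the caller simply propagates.
  return StringVal::CopyFrom(
      ctx, reinterpret_cast<const uint8_t*>(upper.data()), upper.size());
}

This mirrors what the UpperUdf change in udf-test.cc does with its explicit is_null check after StringVal(context, input.len).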