diff --git a/be/src/exprs/scalar-fn-call.cc b/be/src/exprs/scalar-fn-call.cc index e51e20a77..3d92cdc52 100644 --- a/be/src/exprs/scalar-fn-call.cc +++ b/be/src/exprs/scalar-fn-call.cc @@ -83,15 +83,18 @@ Status ScalarFnCall::Prepare(RuntimeState* state, const RowDescriptor& desc, context_index_ = context->Register(state, return_type, arg_types, varargs_buffer_size); // If the codegen object hasn't been created yet and we're calling a builtin or native - // UDF with <= 3 non-variadic arguments, we can use the interpreted path and call the + // UDF with <= 8 non-variadic arguments, we can use the interpreted path and call the // builtin without codegen. This saves us the overhead of creating the codegen object // when it's not necessary (i.e., in plan fragments with no codegen-enabled operators). + // In addition, we can never codegen char arguments. // TODO: codegen for char arguments - if (char_arg || (!state->codegen_created() && NumFixedArgs() <= 3 && + if (char_arg || (!state->codegen_created() && NumFixedArgs() <= 8 && (fn_.binary_type == TFunctionBinaryType::BUILTIN || fn_.binary_type == TFunctionBinaryType::NATIVE))) { + // Builtins with char arguments must still have <= 8 arguments. + // TODO: delete when we have codegen for char arguments if (char_arg) { - DCHECK(NumFixedArgs() <= 3 && fn_.binary_type == TFunctionBinaryType::BUILTIN); + DCHECK(NumFixedArgs() <= 8 && fn_.binary_type == TFunctionBinaryType::BUILTIN); } Status status = LibCache::instance()->GetSoFunctionPtr( fn_.hdfs_location, fn_.scalar_fn.symbol, &scalar_fn_, &cache_entry_); @@ -534,6 +537,38 @@ RETURN_TYPE ScalarFnCall::InterpretEval(ExprContext* context, TupleRow* row) { const AnyVal& a2, const AnyVal& a3); return reinterpret_cast(scalar_fn_)(fn_ctx, *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2]); + case 4: + typedef RETURN_TYPE (*ScalarFn4)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3]); + case 5: + typedef RETURN_TYPE (*ScalarFn5)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4]); + case 6: + typedef RETURN_TYPE (*ScalarFn6)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5, + const AnyVal& a6); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4], *(*input_vals)[5]); + case 7: + typedef RETURN_TYPE (*ScalarFn7)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5, + const AnyVal& a6, const AnyVal& a7); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4], *(*input_vals)[5], *(*input_vals)[6]); + case 8: + typedef RETURN_TYPE (*ScalarFn8)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5, + const AnyVal& a6, const AnyVal& a7, const AnyVal& a8); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4], *(*input_vals)[5], *(*input_vals)[6], *(*input_vals)[7]); default: DCHECK(false) << "Interpreted path not implemented. We should have " << "codegen'd the wrapper"; @@ -561,6 +596,43 @@ RETURN_TYPE ScalarFnCall::InterpretEval(ExprContext* context, TupleRow* row) { const AnyVal& a2, const AnyVal& a3, int num_varargs, const AnyVal* varargs); return reinterpret_cast(scalar_fn_)(fn_ctx, *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], num_varargs, varargs); + case 4: + typedef RETURN_TYPE (*VarargFn4)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, int num_varargs, + const AnyVal* varargs); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + num_varargs, varargs); + case 5: + typedef RETURN_TYPE (*VarargFn5)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5, + int num_varargs, const AnyVal* varargs); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4], num_varargs, varargs); + case 6: + typedef RETURN_TYPE (*VarargFn6)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5, + const AnyVal& a6, int num_varargs, const AnyVal* varargs); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4], *(*input_vals)[5], num_varargs, varargs); + case 7: + typedef RETURN_TYPE (*VarargFn7)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5, + const AnyVal& a6, const AnyVal& a7, int num_varargs, const AnyVal* varargs); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4], *(*input_vals)[5], *(*input_vals)[6], num_varargs, varargs); + case 8: + typedef RETURN_TYPE (*VarargFn8)(FunctionContext*, const AnyVal& a1, + const AnyVal& a2, const AnyVal& a3, const AnyVal& a4, const AnyVal& a5, + const AnyVal& a6, const AnyVal& a7, const AnyVal& a8, int num_varargs, + const AnyVal* varargs); + return reinterpret_cast(scalar_fn_)(fn_ctx, + *(*input_vals)[0], *(*input_vals)[1], *(*input_vals)[2], *(*input_vals)[3], + *(*input_vals)[4], *(*input_vals)[5], *(*input_vals)[6], *(*input_vals)[7], + num_varargs, varargs); default: DCHECK(false) << "Interpreted path not implemented. We should have " << "codegen'd the wrapper"; diff --git a/be/src/testutil/test-udfs.cc b/be/src/testutil/test-udfs.cc index b06d6b000..cf2a5e6f0 100644 --- a/be/src/testutil/test-udfs.cc +++ b/be/src/testutil/test-udfs.cc @@ -285,3 +285,31 @@ BigIntVal DoubleFreeTest(FunctionContext* context, BigIntVal bytes) { extern "C" BigIntVal UnmangledSymbol(FunctionContext* context) { return BigIntVal(5); } + +// Functions to test interpreted path +IntVal FourArgs(FunctionContext* context, const IntVal& v1, const IntVal& v2, + const IntVal& v3, const IntVal& v4) { + return IntVal(v1.val + v2.val + v3.val + v4.val); +} + +IntVal FiveArgs(FunctionContext* context, const IntVal& v1, const IntVal& v2, + const IntVal& v3, const IntVal& v4, const IntVal& v5) { + return IntVal(v1.val + v2.val + v3.val + v4.val + v5.val); +} + +IntVal SixArgs(FunctionContext* context, const IntVal& v1, const IntVal& v2, + const IntVal& v3, const IntVal& v4, const IntVal& v5, const IntVal& v6) { + return IntVal(v1.val + v2.val + v3.val + v4.val + v5.val + v6.val); +} + +IntVal SevenArgs(FunctionContext* context, const IntVal& v1, const IntVal& v2, + const IntVal& v3, const IntVal& v4, const IntVal& v5, const IntVal& v6, + const IntVal& v7) { + return IntVal(v1.val + v2.val + v3.val + v4.val + v5.val + v6.val + v7.val); +} + +IntVal EightArgs(FunctionContext* context, const IntVal& v1, const IntVal& v2, + const IntVal& v3, const IntVal& v4, const IntVal& v5, const IntVal& v6, + const IntVal& v7, const IntVal& v8) { + return IntVal(v1.val + v2.val + v3.val + v4.val + v5.val + v6.val + v7.val + v8.val); +} diff --git a/testdata/workloads/functional-query/queries/QueryTest/udf.test b/testdata/workloads/functional-query/queries/QueryTest/udf.test index d2f2d8240..bdd5ad0af 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/udf.test +++ b/testdata/workloads/functional-query/queries/QueryTest/udf.test @@ -480,3 +480,38 @@ INT ---- RESULTS NULL ==== +---- QUERY +select four_args(1,2,3,4); +---- TYPES +INT +---- RESULTS +10 +==== +---- QUERY +select five_args(1,2,3,4,5); +---- TYPES +INT +---- RESULTS +15 +==== +---- QUERY +select six_args(1,2,3,4,5,6); +---- TYPES +INT +---- RESULTS +21 +==== +---- QUERY +select seven_args(1,2,3,4,5,6,7); +---- TYPES +INT +---- RESULTS +28 +==== +---- QUERY +select eight_args(1,2,3,4,5,6,7,8); +---- TYPES +INT +---- RESULTS +36 +==== diff --git a/tests/query_test/test_udfs.py b/tests/query_test/test_udfs.py index 6cdcfc0a9..baed57564 100644 --- a/tests/query_test/test_udfs.py +++ b/tests/query_test/test_udfs.py @@ -248,6 +248,11 @@ drop function if exists {database}.validate_open(int); drop function if exists {database}.mem_test(bigint); drop function if exists {database}.mem_test_leaks(bigint); drop function if exists {database}.unmangled_symbol(); +drop function if exists {database}.four_args(int, int, int, int); +drop function if exists {database}.five_args(int, int, int, int, int); +drop function if exists {database}.six_args(int, int, int, int, int, int); +drop function if exists {database}.seven_args(int, int, int, int, int, int, int); +drop function if exists {database}.eight_args(int, int, int, int, int, int, int, int); create database if not exists {database}; @@ -347,4 +352,19 @@ prepare_fn='MemTestPrepare'; -- Regression test for IMPALA-1475 create function {database}.unmangled_symbol() returns bigint location '{location}' symbol='UnmangledSymbol'; + +create function {database}.four_args(int, int, int, int) returns int +location '{location}' symbol='FourArgs'; + +create function {database}.five_args(int, int, int, int, int) returns int +location '{location}' symbol='FiveArgs'; + +create function {database}.six_args(int, int, int, int, int, int) returns int +location '{location}' symbol='SixArgs'; + +create function {database}.seven_args(int, int, int, int, int, int, int) returns int +location '{location}' symbol='SevenArgs'; + +create function {database}.eight_args(int, int, int, int, int, int, int, int) returns int +location '{location}' symbol='EightArgs'; """