diff --git a/be/src/testutil/test-udas.cc b/be/src/testutil/test-udas.cc index 549f2f0cf..806a97183 100644 --- a/be/src/testutil/test-udas.cc +++ b/be/src/testutil/test-udas.cc @@ -57,36 +57,30 @@ StringVal AggFinalize(FunctionContext*, const StringVal& v) { // Defines AggIntermediate(int) returns BIGINT intermediate STRING void AggIntermediate(FunctionContext* context, const IntVal&, StringVal*) {} -void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) { +static void ValidateFunctionContext(const FunctionContext* context) { assert(context->GetNumArgs() == 1); assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); } +void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) { + ValidateFunctionContext(context); +} void AggIntermediateInit(FunctionContext* context, StringVal*) { - assert(context->GetNumArgs() == 1); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); - assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); + ValidateFunctionContext(context); } void AggIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) { - assert(context->GetNumArgs() == 1); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); - assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); + ValidateFunctionContext(context); } BigIntVal AggIntermediateFinalize(FunctionContext* context, const StringVal&) { - assert(context->GetNumArgs() == 1); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); - assert(context->GetReturnType().type == 
FunctionContext::TYPE_BIGINT); + ValidateFunctionContext(context); return BigIntVal::null(); } // Defines AggDecimalIntermediate(DECIMAL(1,2), INT) returns DECIMAL(5,6) // intermediate DECIMAL(3,4) // Useful to test that type parameters are plumbed through. -void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, const IntVal&, DecimalVal*) { +static void ValidateFunctionContext2(const FunctionContext* context) { assert(context->GetNumArgs() == 2); assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); assert(context->GetArgType(0)->precision == 2); @@ -99,47 +93,52 @@ void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, c assert(context->GetReturnType().precision == 6); assert(context->GetReturnType().scale == 5); } +void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, + const IntVal&, DecimalVal*) { + ValidateFunctionContext2(context); +} void AggDecimalIntermediateInit(FunctionContext* context, DecimalVal*) { - assert(context->GetNumArgs() == 2); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); - assert(context->GetArgType(0)->precision == 2); - assert(context->GetArgType(0)->scale == 1); - assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetIntermediateType().precision == 4); - assert(context->GetIntermediateType().scale == 3); - assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetReturnType().precision == 6); - assert(context->GetReturnType().scale == 5); + ValidateFunctionContext2(context); } -void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, DecimalVal*) { - assert(context->GetNumArgs() == 2); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); - assert(context->GetArgType(0)->precision == 2); - assert(context->GetArgType(0)->scale == 1); - 
assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetIntermediateType().precision == 4); - assert(context->GetIntermediateType().scale == 3); - assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetReturnType().precision == 6); - assert(context->GetReturnType().scale == 5); +void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, + DecimalVal*) { + ValidateFunctionContext2(context); } DecimalVal AggDecimalIntermediateFinalize(FunctionContext* context, const DecimalVal&) { - assert(context->GetNumArgs() == 2); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); - assert(context->GetArgType(0)->precision == 2); - assert(context->GetArgType(0)->scale == 1); - assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetIntermediateType().precision == 4); - assert(context->GetIntermediateType().scale == 3); - assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetReturnType().precision == 6); - assert(context->GetReturnType().scale == 5); + ValidateFunctionContext2(context); return DecimalVal::null(); } +// Defines AggStringIntermediate(DECIMAL(20,10), BIGINT, STRING) returns DECIMAL(20,0) +// intermediate STRING. +// Useful to test decimal input types with string as intermediate types. 
+static void ValidateFunctionContext3(const FunctionContext* context) { + assert(context->GetNumArgs() == 3); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 20); + assert(context->GetArgType(0)->scale == 10); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_BIGINT); + assert(context->GetArgType(2)->type == FunctionContext::TYPE_STRING); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 20); + assert(context->GetReturnType().scale == 0); +} +void AggStringIntermediateUpdate(FunctionContext* context, const DecimalVal&, + const BigIntVal&, const StringVal&, StringVal*) { + ValidateFunctionContext3(context); +} +void AggStringIntermediateInit(FunctionContext* context, StringVal*) { + ValidateFunctionContext3(context); +} +void AggStringIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) { + ValidateFunctionContext3(context); +} +DecimalVal AggStringIntermediateFinalize(FunctionContext* context, const StringVal&) { + ValidateFunctionContext3(context); + return DecimalVal(100); +} + // Defines MemTest(bigint) return bigint // "Allocates" the specified number of bytes in the update function and frees them in the // serialize function. Useful for testing mem limits. 
diff --git a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java index c9d098db6..1e06254c3 100644 --- a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java +++ b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java @@ -102,8 +102,12 @@ public class FunctionCallExpr extends Expr { FunctionCallExpr agg, List params) { Preconditions.checkState(agg.isAnalyzed()); Preconditions.checkState(agg.isAggregateFunction()); + // If the input aggregate function is already a merge aggregate function (due to + // 2-phase aggregation), its input types will be the intermediate value types. The + // original input argument exprs are in 'agg.mergeAggInputFn_' so use it instead. + FunctionCallExpr mergeAggInputFn = agg.isMergeAggFn() ? agg.mergeAggInputFn_ : agg; FunctionCallExpr result = new FunctionCallExpr( - agg.fnName_, new FunctionParams(false, params), agg); + agg.fnName_, new FunctionParams(false, params), mergeAggInputFn); // Inherit the function object from 'agg'. result.fn_ = agg.fn_; result.type_ = agg.type_; @@ -127,8 +131,8 @@ public class FunctionCallExpr extends Expr { fnName_ = other.fnName_; isAnalyticFnCall_ = other.isAnalyticFnCall_; isInternalFnCall_ = other.isInternalFnCall_; - mergeAggInputFn_ = - other.mergeAggInputFn_ == null ? null : (FunctionCallExpr)other.mergeAggInputFn_.clone(); + mergeAggInputFn_ = other.mergeAggInputFn_ == null ? + null : (FunctionCallExpr)other.mergeAggInputFn_.clone(); // Clone the params in a way that keeps the children_ and the params.exprs() // in sync. The children have already been cloned in the super c'tor. 
if (other.params_.isStar()) { @@ -574,7 +578,8 @@ public class FunctionCallExpr extends Expr { void validateMergeAggFn(FunctionCallExpr inputAggFn) { Preconditions.checkState(isMergeAggFn()); List copiedInputExprs = mergeAggInputFn_.getChildren(); - List inputExprs = inputAggFn.getChildren(); + List inputExprs = inputAggFn.isMergeAggFn() ? + inputAggFn.mergeAggInputFn_.getChildren() : inputAggFn.getChildren(); Preconditions.checkState(copiedInputExprs.size() == inputExprs.size()); for (int i = 0; i < inputExprs.size(); ++i) { Type copiedInputType = copiedInputExprs.get(i).getType(); diff --git a/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test b/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test index a1177b0bc..b5c3970e6 100644 --- a/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test +++ b/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test @@ -554,6 +554,46 @@ PLAN-ROOT SINK 01:SCAN HDFS [functional.alltypes] partitions=24/24 files=24 size=478.45KB ==== +# Mixed distinct and non-distinct agg with intermediate type different from input type +# Regression test for IMPALA-5251 to exercise validateMergeAggFn() in FunctionCallExpr. 
+select avg(l_quantity), ndv(l_discount), count(distinct l_partkey) +from tpch_parquet.lineitem; +---- PLAN +PLAN-ROOT SINK +| +02:AGGREGATE [FINALIZE] +| output: count(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount) +| +01:AGGREGATE +| output: avg(l_quantity), ndv(l_discount) +| group by: l_partkey +| +00:SCAN HDFS [tpch_parquet.lineitem] + partitions=1/1 files=3 size=193.74MB +---- DISTRIBUTEDPLAN +PLAN-ROOT SINK +| +06:AGGREGATE [FINALIZE] +| output: count:merge(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount) +| +05:EXCHANGE [UNPARTITIONED] +| +02:AGGREGATE +| output: count(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount) +| +04:AGGREGATE +| output: avg:merge(l_quantity), ndv:merge(l_discount) +| group by: l_partkey +| +03:EXCHANGE [HASH(l_partkey)] +| +01:AGGREGATE [STREAMING] +| output: avg(l_quantity), ndv(l_discount) +| group by: l_partkey +| +00:SCAN HDFS [tpch_parquet.lineitem] + partitions=1/1 files=3 size=193.74MB +==== # test that aggregations are not placed below an unpartitioned exchange with a limit select count(*) from (select * from functional.alltypes limit 10) t ---- PLAN diff --git a/testdata/workloads/functional-query/queries/QueryTest/uda.test b/testdata/workloads/functional-query/queries/QueryTest/uda.test index 3a9bbbec0..932b94a1b 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/uda.test +++ b/testdata/workloads/functional-query/queries/QueryTest/uda.test @@ -88,3 +88,55 @@ from functional.decimal_tbl NULL,5 ---- TYPES decimal,bigint +==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. +# This includes distinct aggregate expression to test IMPALA-5251. +# It also relies on asserts in the UDA function. 
+select + agg_string_intermediate(cast(c1 as decimal(20,10)), 1000, "foobar"), + agg_decimal_intermediate(cast(c3 as decimal(2,1)), 2), + agg_intermediate(int_col), + avg(c2), + min(c3-c1), + max(c1+c3), + count(distinct int_col), + sum(distinct int_col) +from + functional.alltypesagg, + functional.decimal_tiny +---- RESULTS +100,NULL,NULL,160.49989,-10.0989,11.8989,999,499500 +---- TYPES +decimal,decimal,bigint,decimal,decimal,decimal,bigint,bigint +==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. +# This includes distinct aggregate expression to test IMPALA-5251. +# It also relies on asserts in the UDA function. +select + agg_string_intermediate(cast(c1 as decimal(20,10)), 1000, "foobar"), + agg_decimal_intermediate(cast(c3 as decimal(2,1)), 2), + agg_intermediate(int_col), + ndv(c2), + sum(distinct c1)/count(distinct c1) +from + functional.alltypesagg, + functional.decimal_tiny +group by + year,month,day +---- RESULTS +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +---- TYPES +decimal,decimal,bigint,bigint,decimal +==== \ No newline at end of file diff --git a/tests/query_test/test_udfs.py b/tests/query_test/test_udfs.py index 56ce233a5..ec24c9f60 100644 --- a/tests/query_test/test_udfs.py +++ b/tests/query_test/test_udfs.py @@ -103,6 +103,11 @@ create aggregate function {database}.agg_decimal_intermediate(decimal(2,1), int) returns decimal(6,5) intermediate decimal(4,3) location '{location}' init_fn='AggDecimalIntermediateInit' update_fn='AggDecimalIntermediateUpdate' merge_fn='AggDecimalIntermediateMerge' finalize_fn='AggDecimalIntermediateFinalize'; + +create aggregate function {database}.agg_string_intermediate(decimal(20,10), bigint, string) +returns decimal(20,0) intermediate string location 
'{location}' +init_fn='AggStringIntermediateInit' update_fn='AggStringIntermediateUpdate' +merge_fn='AggStringIntermediateMerge' finalize_fn='AggStringIntermediateFinalize'; """ # Create test UDF functions in {database} from library {location}