From 42ca45e8307ba4c831ad7ac8da86bbbd957fe4cd Mon Sep 17 00:00:00 2001 From: Michael Ho Date: Tue, 25 Apr 2017 00:10:08 -0700 Subject: [PATCH] IMPALA-5251: Fix propagation of input exprs' types in 2-phase agg Since commit d2d3f4c (on asf-master), TAggregateExpr contains the logical input types of the Aggregate Expr. The reason they are included is that merging aggregate expressions will have input tyes of the intermediate values which aren't necessarily the same as the input types. For instance, NDV() uses a binary blob as its intermediate value and it's passed to its merge aggregate expressions as a StringVal but the input type of NDV() in the query could be DecimalVal. In this case, we consider DecimalVal as the logical input type while StringVal is the intermediate type. The logical input types are accessed by the BE via GetConstFnAttr() during interpretation and constant propagation during codegen. To handle distinct aggregate expressions (e.g. select count(distinct)), the FE uses 2-phase aggregation by introducing an extra phase of split/merge aggregation in which the distinct aggregate expressions' inputs are coverted and added to the group-by expressions in the first phase while the non-distinct aggregate expressions go through the normal split/merge treatement. The bug is that the existing code incorrectly propagates the intermediate types of the non-grouping aggregate expressions as the logical input types to the merging aggregate expressions in the second phase of aggregation. The input aggregate expressions for the non-distinct aggregate expressions in the second phase aggregation are already merging aggregate expressions (from phase one) in which case we should not treat its input types as logical input types. This change fixes the problem above by checking if the input aggregate expression passed to FunctionCallExpr.createMergeAggCall() is already a merging aggregate expression. If so, it will use the logical input types recorded in its 'mergeAggInputFn_' as references for its logical input types instead of the aggregate expression input types themselves. Change-Id: I158303b20d1afdff23c67f3338b9c4af2ad80691 Reviewed-on: http://gerrit.cloudera.org:8080/6724 Reviewed-by: Alex Behm Tested-by: Impala Public Jenkins --- be/src/testutil/test-udas.cc | 95 +++++++++---------- .../impala/analysis/FunctionCallExpr.java | 13 ++- .../queries/PlannerTest/aggregation.test | 40 ++++++++ .../queries/QueryTest/uda.test | 52 ++++++++++ tests/query_test/test_udfs.py | 5 + 5 files changed, 153 insertions(+), 52 deletions(-) diff --git a/be/src/testutil/test-udas.cc b/be/src/testutil/test-udas.cc index 549f2f0cf..806a97183 100644 --- a/be/src/testutil/test-udas.cc +++ b/be/src/testutil/test-udas.cc @@ -57,36 +57,30 @@ StringVal AggFinalize(FunctionContext*, const StringVal& v) { // Defines AggIntermediate(int) returns BIGINT intermediate STRING void AggIntermediate(FunctionContext* context, const IntVal&, StringVal*) {} -void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) { +static void ValidateFunctionContext(const FunctionContext* context) { assert(context->GetNumArgs() == 1); assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); } +void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) { + ValidateFunctionContext(context); +} void AggIntermediateInit(FunctionContext* context, StringVal*) { - assert(context->GetNumArgs() == 1); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); - assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); + ValidateFunctionContext(context); } void AggIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) { - assert(context->GetNumArgs() == 1); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); - assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); + ValidateFunctionContext(context); } BigIntVal AggIntermediateFinalize(FunctionContext* context, const StringVal&) { - assert(context->GetNumArgs() == 1); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); - assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); + ValidateFunctionContext(context); return BigIntVal::null(); } // Defines AggDecimalIntermediate(DECIMAL(1,2), INT) returns DECIMAL(5,6) // intermediate DECIMAL(3,4) // Useful to test that type parameters are plumbed through. -void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, const IntVal&, DecimalVal*) { +static void ValidateFunctionContext2(const FunctionContext* context) { assert(context->GetNumArgs() == 2); assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); assert(context->GetArgType(0)->precision == 2); @@ -99,47 +93,52 @@ void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, c assert(context->GetReturnType().precision == 6); assert(context->GetReturnType().scale == 5); } +void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, + const IntVal&, DecimalVal*) { + ValidateFunctionContext2(context); +} void AggDecimalIntermediateInit(FunctionContext* context, DecimalVal*) { - assert(context->GetNumArgs() == 2); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); - assert(context->GetArgType(0)->precision == 2); - assert(context->GetArgType(0)->scale == 1); - assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetIntermediateType().precision == 4); - assert(context->GetIntermediateType().scale == 3); - assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetReturnType().precision == 6); - assert(context->GetReturnType().scale == 5); + ValidateFunctionContext2(context); } -void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, DecimalVal*) { - assert(context->GetNumArgs() == 2); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); - assert(context->GetArgType(0)->precision == 2); - assert(context->GetArgType(0)->scale == 1); - assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetIntermediateType().precision == 4); - assert(context->GetIntermediateType().scale == 3); - assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetReturnType().precision == 6); - assert(context->GetReturnType().scale == 5); +void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, + DecimalVal*) { + ValidateFunctionContext2(context); } DecimalVal AggDecimalIntermediateFinalize(FunctionContext* context, const DecimalVal&) { - assert(context->GetNumArgs() == 2); - assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); - assert(context->GetArgType(0)->precision == 2); - assert(context->GetArgType(0)->scale == 1); - assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); - assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetIntermediateType().precision == 4); - assert(context->GetIntermediateType().scale == 3); - assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); - assert(context->GetReturnType().precision == 6); - assert(context->GetReturnType().scale == 5); + ValidateFunctionContext2(context); return DecimalVal::null(); } +// Defines AggStringIntermediate(DECIMAL(20,10), BIGINT, STRING) returns DECIMAL(20,0) +// intermediate STRING. +// Useful to test decimal input types with string as intermediate types. +static void ValidateFunctionContext3(const FunctionContext* context) { + assert(context->GetNumArgs() == 3); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 20); + assert(context->GetArgType(0)->scale == 10); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_BIGINT); + assert(context->GetArgType(2)->type == FunctionContext::TYPE_STRING); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 20); + assert(context->GetReturnType().scale == 0); +} +void AggStringIntermediateUpdate(FunctionContext* context, const DecimalVal&, + const BigIntVal&, const StringVal&, StringVal*) { + ValidateFunctionContext3(context); +} +void AggStringIntermediateInit(FunctionContext* context, StringVal*) { + ValidateFunctionContext3(context); +} +void AggStringIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) { + ValidateFunctionContext3(context); +} +DecimalVal AggStringIntermediateFinalize(FunctionContext* context, const StringVal&) { + ValidateFunctionContext3(context); + return DecimalVal(100); +} + // Defines MemTest(bigint) return bigint // "Allocates" the specified number of bytes in the update function and frees them in the // serialize function. Useful for testing mem limits. diff --git a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java index c9d098db6..1e06254c3 100644 --- a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java +++ b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java @@ -102,8 +102,12 @@ public class FunctionCallExpr extends Expr { FunctionCallExpr agg, List params) { Preconditions.checkState(agg.isAnalyzed()); Preconditions.checkState(agg.isAggregateFunction()); + // If the input aggregate function is already a merge aggregate function (due to + // 2-phase aggregation), its input types will be the intermediate value types. The + // original input argument exprs are in 'agg.mergeAggInputFn_' so use it instead. + FunctionCallExpr mergeAggInputFn = agg.isMergeAggFn() ? agg.mergeAggInputFn_ : agg; FunctionCallExpr result = new FunctionCallExpr( - agg.fnName_, new FunctionParams(false, params), agg); + agg.fnName_, new FunctionParams(false, params), mergeAggInputFn); // Inherit the function object from 'agg'. result.fn_ = agg.fn_; result.type_ = agg.type_; @@ -127,8 +131,8 @@ public class FunctionCallExpr extends Expr { fnName_ = other.fnName_; isAnalyticFnCall_ = other.isAnalyticFnCall_; isInternalFnCall_ = other.isInternalFnCall_; - mergeAggInputFn_ = - other.mergeAggInputFn_ == null ? null : (FunctionCallExpr)other.mergeAggInputFn_.clone(); + mergeAggInputFn_ = other.mergeAggInputFn_ == null ? + null : (FunctionCallExpr)other.mergeAggInputFn_.clone(); // Clone the params in a way that keeps the children_ and the params.exprs() // in sync. The children have already been cloned in the super c'tor. if (other.params_.isStar()) { @@ -574,7 +578,8 @@ public class FunctionCallExpr extends Expr { void validateMergeAggFn(FunctionCallExpr inputAggFn) { Preconditions.checkState(isMergeAggFn()); List copiedInputExprs = mergeAggInputFn_.getChildren(); - List inputExprs = inputAggFn.getChildren(); + List inputExprs = inputAggFn.isMergeAggFn() ? + inputAggFn.mergeAggInputFn_.getChildren() : inputAggFn.getChildren(); Preconditions.checkState(copiedInputExprs.size() == inputExprs.size()); for (int i = 0; i < inputExprs.size(); ++i) { Type copiedInputType = copiedInputExprs.get(i).getType(); diff --git a/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test b/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test index a1177b0bc..b5c3970e6 100644 --- a/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test +++ b/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test @@ -554,6 +554,46 @@ PLAN-ROOT SINK 01:SCAN HDFS [functional.alltypes] partitions=24/24 files=24 size=478.45KB ==== +# Mixed distinct and non-distinct agg with intermediate type different from input type +# Regression test for IMPALA-5251 to exercise validateMergeAggFn() in FunctionCallExpr. +select avg(l_quantity), ndv(l_discount), count(distinct l_partkey) +from tpch_parquet.lineitem; +---- PLAN +PLAN-ROOT SINK +| +02:AGGREGATE [FINALIZE] +| output: count(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount) +| +01:AGGREGATE +| output: avg(l_quantity), ndv(l_discount) +| group by: l_partkey +| +00:SCAN HDFS [tpch_parquet.lineitem] + partitions=1/1 files=3 size=193.74MB +---- DISTRIBUTEDPLAN +PLAN-ROOT SINK +| +06:AGGREGATE [FINALIZE] +| output: count:merge(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount) +| +05:EXCHANGE [UNPARTITIONED] +| +02:AGGREGATE +| output: count(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount) +| +04:AGGREGATE +| output: avg:merge(l_quantity), ndv:merge(l_discount) +| group by: l_partkey +| +03:EXCHANGE [HASH(l_partkey)] +| +01:AGGREGATE [STREAMING] +| output: avg(l_quantity), ndv(l_discount) +| group by: l_partkey +| +00:SCAN HDFS [tpch_parquet.lineitem] + partitions=1/1 files=3 size=193.74MB +==== # test that aggregations are not placed below an unpartitioned exchange with a limit select count(*) from (select * from functional.alltypes limit 10) t ---- PLAN diff --git a/testdata/workloads/functional-query/queries/QueryTest/uda.test b/testdata/workloads/functional-query/queries/QueryTest/uda.test index 3a9bbbec0..932b94a1b 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/uda.test +++ b/testdata/workloads/functional-query/queries/QueryTest/uda.test @@ -88,3 +88,55 @@ from functional.decimal_tbl NULL,5 ---- TYPES decimal,bigint +==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. +# This includes distinct aggregate expression to test IMPALA-5251. +# It also relies on asserts in the UDA funciton. +select + agg_string_intermediate(cast(c1 as decimal(20,10)), 1000, "foobar"), + agg_decimal_intermediate(cast(c3 as decimal(2,1)), 2), + agg_intermediate(int_col), + avg(c2), + min(c3-c1), + max(c1+c3), + count(distinct int_col), + sum(distinct int_col) +from + functional.alltypesagg, + functional.decimal_tiny +---- RESULTS +100,NULL,NULL,160.49989,-10.0989,11.8989,999,499500 +---- TYPES +decimal,decimal,bigint,decimal,decimal,decimal,bigint,bigint +==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. +# This includes distinct aggregate expression to test IMPALA-5251. +# It also relies on asserts in the UDA funciton. +select + agg_string_intermediate(cast(c1 as decimal(20,10)), 1000, "foobar"), + agg_decimal_intermediate(cast(c3 as decimal(2,1)), 2), + agg_intermediate(int_col), + ndv(c2), + sum(distinct c1)/count(distinct c1) +from + functional.alltypesagg, + functional.decimal_tiny +group by + year,month,day +---- RESULTS +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +100,NULL,NULL,99,5.4994 +---- TYPES +decimal,decimal,bigint,bigint,decimal +==== \ No newline at end of file diff --git a/tests/query_test/test_udfs.py b/tests/query_test/test_udfs.py index 56ce233a5..ec24c9f60 100644 --- a/tests/query_test/test_udfs.py +++ b/tests/query_test/test_udfs.py @@ -103,6 +103,11 @@ create aggregate function {database}.agg_decimal_intermediate(decimal(2,1), int) returns decimal(6,5) intermediate decimal(4,3) location '{location}' init_fn='AggDecimalIntermediateInit' update_fn='AggDecimalIntermediateUpdate' merge_fn='AggDecimalIntermediateMerge' finalize_fn='AggDecimalIntermediateFinalize'; + +create aggregate function {database}.agg_string_intermediate(decimal(20,10), bigint, string) +returns decimal(20,0) intermediate string location '{location}' +init_fn='AggStringIntermediateInit' update_fn='AggStringIntermediateUpdate' +merge_fn='AggStringIntermediateMerge' finalize_fn='AggStringIntermediateFinalize'; """ # Create test UDF functions in {database} from library {location}