From 6f31dc7f8aaabedd75badcd5d91838f339dfc481 Mon Sep 17 00:00:00 2001 From: Victor Bittorf Date: Fri, 11 Apr 2014 09:51:18 -0700 Subject: [PATCH] Adding STDDEV builtin. Change-Id: I79e5aee1e9e879aa2d09078ab45bc149675e1d4a Reviewed-on: http://gerrit.ent.cloudera.com:8080/2341 Reviewed-by: Victor Bittorf Tested-by: jenkins (cherry picked from commit a42c375d933c0b7ffe7c9b6702777679492d7ad6) Reviewed-on: http://gerrit.ent.cloudera.com:8080/2464 --- be/src/exprs/aggregate-functions.cc | 109 ++++++++++++++++++ be/src/exprs/aggregate-functions.h | 21 +++- .../com/cloudera/impala/catalog/Catalog.java | 67 +++++++++++ .../queries/QueryTest/aggregation.test | 70 +++++++++++ 4 files changed, 265 insertions(+), 2 deletions(-) diff --git a/be/src/exprs/aggregate-functions.cc b/be/src/exprs/aggregate-functions.cc index 0cdebae2e..546b59d9b 100644 --- a/be/src/exprs/aggregate-functions.cc +++ b/be/src/exprs/aggregate-functions.cc @@ -29,6 +29,17 @@ using namespace std; // the custom code in aggregation node. namespace impala { +// Converts any UDF Val Type to a string representation +template +StringVal ToStringVal(FunctionContext* context, T val) { + stringstream ss; + ss << val; + const string &str = ss.str(); + StringVal string_val(context, str.size()); + memcpy(string_val.ptr, str.c_str(), str.size()); + return string_val; +} + // Delimiter to use if the separator is NULL. static const StringVal DEFAULT_STRING_CONCAT_DELIM((uint8_t*)", ", 2); @@ -496,6 +507,91 @@ StringVal AggregateFunctions::HllFinalize(FunctionContext* ctx, const StringVal& return result_str; } +// An implementation of a simple single pass variance algorithm. A standard UDA must +// be single pass (i.e. does not scan the table more than once), so the most canonical +// two pass approach is not practical. +struct KnuthVarianceState { + double mean; + double m2; + int64_t count; +}; + +// Set pop=true for population variance, false for sample variance +double ComputeKnuthVariance(const KnuthVarianceState& state, bool pop) { + // Return zero for 1 tuple specified by + // http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions212.htm + if (state.count == 1) return 0.0; + if (pop) return state.m2 / state.count; + return state.m2 / (state.count - 1); +} + +void AggregateFunctions::KnuthVarInit(FunctionContext* ctx, StringVal* dst) { + dst->is_null = false; + dst->len = sizeof(KnuthVarianceState); + dst->ptr = ctx->Allocate(dst->len); + memset(dst->ptr, 0, dst->len); +} + +template +void AggregateFunctions::KnuthVarUpdate(FunctionContext* ctx, const T& src, + StringVal* dst) { + if (src.is_null) return; + KnuthVarianceState* state = reinterpret_cast(dst->ptr); + double temp = 1 + state->count; + double delta = src.val - state->mean; + double r = delta / temp; + state->mean += r; + state->m2 += state->count * delta * r; + state->count = temp; +} + +void AggregateFunctions::KnuthVarMerge(FunctionContext* ctx, const StringVal& src, + StringVal* dst) { + // Reference implementation: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + KnuthVarianceState* src_state = reinterpret_cast(src.ptr); + KnuthVarianceState* dst_state = reinterpret_cast(dst->ptr); + if (src_state->count == 0) return; + double delta = dst_state->mean - src_state->mean; + double sum_count = dst_state->count + src_state->count; + dst_state->mean = src_state->mean + delta * (dst_state->count / sum_count); + dst_state->m2 = (src_state->m2) + dst_state->m2 + + (delta * delta) * (src_state->count * dst_state->count / sum_count); + dst_state->count = sum_count; +} + +StringVal AggregateFunctions::KnuthVarFinalize(FunctionContext* ctx, + const StringVal& state_sv) { + KnuthVarianceState* state = reinterpret_cast(state_sv.ptr); + if (state->count == 0) return StringVal::null(); + double variance = ComputeKnuthVariance(*state, false); + return ToStringVal(ctx, variance); +} + +StringVal AggregateFunctions::KnuthVarPopFinalize(FunctionContext* ctx, + const StringVal& state_sv) { + KnuthVarianceState* state = reinterpret_cast(state_sv.ptr); + if (state->count == 0) return StringVal::null(); + double variance = ComputeKnuthVariance(*state, true); + return ToStringVal(ctx, variance); +} + +StringVal AggregateFunctions::KnuthStddevFinalize(FunctionContext* ctx, + const StringVal& state_sv) { + KnuthVarianceState* state = reinterpret_cast(state_sv.ptr); + if (state->count == 0) return StringVal::null(); + double variance = ComputeKnuthVariance(*state, false); + return ToStringVal(ctx, sqrt(variance)); +} + +StringVal AggregateFunctions::KnuthStddevPopFinalize(FunctionContext* ctx, + const StringVal& state_sv) { + KnuthVarianceState* state = reinterpret_cast(state_sv.ptr); + if (state->count == 0) return StringVal::null(); + double variance = ComputeKnuthVariance(*state, true); + return ToStringVal(ctx, sqrt(variance)); +} + // Stamp out the templates for the types we need. template void AggregateFunctions::InitZero(FunctionContext*, BigIntVal* dst); @@ -612,4 +708,17 @@ template void AggregateFunctions::HllUpdate( FunctionContext*, const TimestampVal&, StringVal*); template void AggregateFunctions::HllUpdate( FunctionContext*, const DecimalVal&, StringVal*); + +template void AggregateFunctions::KnuthVarUpdate( + FunctionContext*, const TinyIntVal&, StringVal*); +template void AggregateFunctions::KnuthVarUpdate( + FunctionContext*, const SmallIntVal&, StringVal*); +template void AggregateFunctions::KnuthVarUpdate( + FunctionContext*, const IntVal&, StringVal*); +template void AggregateFunctions::KnuthVarUpdate( + FunctionContext*, const BigIntVal&, StringVal*); +template void AggregateFunctions::KnuthVarUpdate( + FunctionContext*, const FloatVal&, StringVal*); +template void AggregateFunctions::KnuthVarUpdate( + FunctionContext*, const DoubleVal&, StringVal*); } diff --git a/be/src/exprs/aggregate-functions.h b/be/src/exprs/aggregate-functions.h index 71c5e7238..71125a117 100644 --- a/be/src/exprs/aggregate-functions.h +++ b/be/src/exprs/aggregate-functions.h @@ -91,9 +91,26 @@ class AggregateFunctions { static void HllUpdate(FunctionContext*, const T& src, StringVal* dst); static void HllMerge(FunctionContext*, const StringVal& src, StringVal* dst); static StringVal HllFinalize(FunctionContext*, const StringVal& src); + + // Knuth's variance algorithm, more numerically stable than canonical stddev + // algorithms; reference implementation: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm + static void KnuthVarInit(FunctionContext* context, StringVal* val); + template + static void KnuthVarUpdate(FunctionContext* context, const T& input, StringVal* val); + static void KnuthVarMerge(FunctionContext* context, const StringVal& src, + StringVal* dst); + static StringVal KnuthVarFinalize(FunctionContext* context, const StringVal& val); + + // Calculates the biased variance, uses KnuthVar Init-Update-Merge functions + static StringVal KnuthVarPopFinalize(FunctionContext* context, const StringVal& val); + + // Calculates STDDEV, uses KnuthVar Init-Update-Merge functions + static StringVal KnuthStddevFinalize(FunctionContext* context, const StringVal& val); + + // Calculates the biased STDDEV, uses KnuthVar Init-Update-Merge functions + static StringVal KnuthStddevPopFinalize(FunctionContext* context, const StringVal& val); }; } - #endif - diff --git a/fe/src/main/java/com/cloudera/impala/catalog/Catalog.java b/fe/src/main/java/com/cloudera/impala/catalog/Catalog.java index 12f0de0c8..d9ab704c4 100644 --- a/fe/src/main/java/com/cloudera/impala/catalog/Catalog.java +++ b/fe/src/main/java/com/cloudera/impala/catalog/Catalog.java @@ -507,6 +507,22 @@ public abstract class Catalog { "3MaxINS_10DecimalValEEEvPN10impala_udf15FunctionContextERKT_PS6_") .build(); + private static final Map STDDEV_UPDATE_SYMBOL = + ImmutableMap.builder() + .put(ColumnType.TINYINT, + "14KnuthVarUpdateIN10impala_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") + .put(ColumnType.SMALLINT, + "14KnuthVarUpdateIN10impala_udf11SmallIntValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") + .put(ColumnType.INT, + "14KnuthVarUpdateIN10impala_udf6IntValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") + .put(ColumnType.BIGINT, + "14KnuthVarUpdateIN10impala_udf9BigIntValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") + .put(ColumnType.FLOAT, + "14KnuthVarUpdateIN10impala_udf8FloatValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") + .put(ColumnType.DOUBLE, + "14KnuthVarUpdateIN10impala_udf9DoubleValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") + .build(); + // Populate all the aggregate builtins in the catalog. // null symbols indicate the function does not need that step of the evaluation. // An empty symbol indicates a TODO for the BE to implement the function. @@ -588,6 +604,57 @@ public abstract class Catalog { stringValSerializeOrFinalize, prefix + "12PcsaFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", true)); + + if (STDDEV_UPDATE_SYMBOL.containsKey(t)) { + db.addBuiltin(AggregateFunction.createBuiltin(db, "stddev", + Lists.newArrayList(t), ColumnType.STRING, ColumnType.STRING, + prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + STDDEV_UPDATE_SYMBOL.get(t), + prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "19KnuthStddevFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + true)); + db.addBuiltin(AggregateFunction.createBuiltin(db, "stddev_samp", + Lists.newArrayList(t), ColumnType.STRING, ColumnType.STRING, + prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + STDDEV_UPDATE_SYMBOL.get(t), + prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "19KnuthStddevFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + true)); + db.addBuiltin(AggregateFunction.createBuiltin(db, "stddev_pop", + Lists.newArrayList(t), ColumnType.STRING, ColumnType.STRING, + prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + STDDEV_UPDATE_SYMBOL.get(t), + prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "22KnuthStddevPopFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + true)); + db.addBuiltin(AggregateFunction.createBuiltin(db, "variance", + Lists.newArrayList(t), ColumnType.STRING, ColumnType.STRING, + prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + STDDEV_UPDATE_SYMBOL.get(t), + prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "16KnuthVarFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + true)); + db.addBuiltin(AggregateFunction.createBuiltin(db, "variance_samp", + Lists.newArrayList(t), ColumnType.STRING, ColumnType.STRING, + prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + STDDEV_UPDATE_SYMBOL.get(t), + prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "16KnuthVarFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + true)); + db.addBuiltin(AggregateFunction.createBuiltin(db, "variance_pop", + Lists.newArrayList(t), ColumnType.STRING, ColumnType.STRING, + prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + STDDEV_UPDATE_SYMBOL.get(t), + prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "19KnuthVarPopFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + true)); + } } // Sum diff --git a/testdata/workloads/functional-query/queries/QueryTest/aggregation.test b/testdata/workloads/functional-query/queries/QueryTest/aggregation.test index b95412fcf..ef5fd4346 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/aggregation.test +++ b/testdata/workloads/functional-query/queries/QueryTest/aggregation.test @@ -1,5 +1,75 @@ ==== ---- QUERY +# test a larger dataset, includes nulls +# the exact result could vary slightly due to numeric instability +# 0.001 is a conservative upperbound on the possible difference in results +SELECT abs(cast(variance(tinyint_col) as double) - 6.66741) < 0.001, + abs(cast(variance(double_col) as double) - 8484680) < 0.001 +from alltypesagg +---- TYPES +boolean, boolean +---- RESULTS +true,true +==== +---- QUERY +# No tuples processed (should return null) +SELECT variance(tinyint_col), stddev(smallint_col), variance_pop(int_col), +stddev_pop(bigint_col) +from alltypesagg WHERE id = -9999999 +---- TYPES +string, string, string, string +---- RESULTS +'NULL','NULL','NULL','NULL' +==== +---- QUERY +# exactly 1 tuple processed (variance & stddev are 0) +SELECT variance(tinyint_col), stddev(smallint_col), variance_pop(int_col), +stddev_pop(bigint_col) +from alltypesagg WHERE id = 1006 +---- TYPES +string, string, string, string +---- RESULTS +'0','0','0','0' +==== +---- QUERY +# Includes one row which is null +SELECT variance(tinyint_col), variance(smallint_col), variance(int_col), +variance(bigint_col), variance(float_col), variance(double_col) +from alltypesagg WHERE id >= 1000 AND id < 1006 +---- TYPES +string, string, string, string, string, string +---- RESULTS +'2.5','2.5','2.5','250','3.025','255.025' +==== +---- QUERY +SELECT variance_pop(tinyint_col), variance_pop(smallint_col), variance_pop(int_col), +variance_pop(bigint_col), variance_pop(float_col), variance_pop(double_col) +from alltypesagg WHERE id >= 1000 AND id < 1006 +---- TYPES +string, string, string, string, string, string +---- RESULTS +'2','2','2','200','2.42','204.02' +==== +---- QUERY +SELECT stddev(tinyint_col), stddev(smallint_col), stddev(int_col), stddev(bigint_col), +stddev(float_col), stddev(double_col) +from alltypesagg WHERE id >= 1000 AND id < 1006 +---- TYPES +string, string, string, string, string, string +---- RESULTS +'1.58114','1.58114','1.58114','15.8114','1.73925','15.9695' +==== +---- QUERY +# no grouping exprs, cols contain nulls except for bool cols +SELECT stddev_pop(tinyint_col), stddev_pop(smallint_col), stddev_pop(int_col), +stddev_pop(bigint_col), stddev_pop(float_col), stddev_pop(double_col) +from alltypesagg WHERE id >= 1000 AND id < 1006 +---- TYPES +string, string, string, string, string, string +---- RESULTS +'1.41421','1.41421','1.41421','14.1421','1.55563','14.2836' +==== +---- QUERY # no grouping exprs, cols contain nulls except for bool cols select count(bool_col), min(bool_col), max(bool_col) from alltypesagg