diff --git a/be/src/exprs/agg-fn-evaluator.cc b/be/src/exprs/agg-fn-evaluator.cc index 0cb430501..cc3096113 100644 --- a/be/src/exprs/agg-fn-evaluator.cc +++ b/be/src/exprs/agg-fn-evaluator.cc @@ -226,7 +226,7 @@ void AggFnEvaluator::Close(RuntimeState* state) { } } -inline void AggFnEvaluator::SetDstSlot(const AnyVal* src, +inline void AggFnEvaluator::SetDstSlot(FunctionContext* ctx, const AnyVal* src, const SlotDescriptor* dst_slot_desc, Tuple* dst) { if (src->is_null) { dst->SetNull(dst_slot_desc->null_indicator_offset()); @@ -264,6 +264,11 @@ inline void AggFnEvaluator::SetDstSlot(const AnyVal* src, *reinterpret_cast(slot) = StringValue::FromStringVal(*reinterpret_cast(src)); return; + case TYPE_CHAR: + if (slot != reinterpret_cast(src)->ptr) { + ctx->SetError("UDA should not set pointer of CHAR(N) intermediate"); + } + return; case TYPE_TIMESTAMP: *reinterpret_cast(slot) = TimestampValue::FromTimestampVal( *reinterpret_cast(src)); @@ -301,8 +306,18 @@ inline void AggFnEvaluator::SetDstSlot(const AnyVal* src, // This function would be replaced in codegen. void AggFnEvaluator::Init(FunctionContext* agg_fn_ctx, Tuple* dst) { DCHECK(init_fn_ != NULL); + if (intermediate_type().type == TYPE_CHAR) { + // For type char, we want to initialize the staging_intermediate_val_ with + // a pointer into the tuple (the UDA should not be allocating it). + void* slot = dst->GetSlot(intermediate_slot_desc_->tuple_offset()); + StringVal* sv = reinterpret_cast(staging_intermediate_val_); + sv->is_null = dst->IsNull(intermediate_slot_desc_->null_indicator_offset()); + sv->ptr = reinterpret_cast( + StringValue::CharSlotToPtr(slot, intermediate_type())); + sv->len = intermediate_type().len; + } reinterpret_cast(init_fn_)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(staging_intermediate_val_, intermediate_slot_desc_, dst); + SetDstSlot(agg_fn_ctx, staging_intermediate_val_, intermediate_slot_desc_, dst); agg_fn_ctx->impl()->set_num_updates(0); agg_fn_ctx->impl()->set_num_removes(0); } @@ -381,7 +396,7 @@ void AggFnEvaluator::Update( default: DCHECK(false) << "NYI"; } - SetDstSlot(staging_intermediate_val_, intermediate_slot_desc_, dst); + SetDstSlot(agg_fn_ctx, staging_intermediate_val_, intermediate_slot_desc_, dst); } void AggFnEvaluator::Merge(FunctionContext* agg_fn_ctx, Tuple* src, Tuple* dst) { @@ -393,7 +408,7 @@ void AggFnEvaluator::Merge(FunctionContext* agg_fn_ctx, Tuple* src, Tuple* dst) // The merge fn always takes one input argument. reinterpret_cast(merge_fn_)(agg_fn_ctx, *staging_merge_input_val_, staging_intermediate_val_); - SetDstSlot(staging_intermediate_val_, intermediate_slot_desc_, dst); + SetDstSlot(agg_fn_ctx, staging_intermediate_val_, intermediate_slot_desc_, dst); } void AggFnEvaluator::SerializeOrFinalize(FunctionContext* agg_fn_ctx, Tuple* src, @@ -420,62 +435,62 @@ void AggFnEvaluator::SerializeOrFinalize(FunctionContext* agg_fn_ctx, Tuple* src case TYPE_BOOLEAN: { typedef BooleanVal(*Fn)(FunctionContext*, AnyVal*); BooleanVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_TINYINT: { typedef TinyIntVal(*Fn)(FunctionContext*, AnyVal*); TinyIntVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_SMALLINT: { typedef SmallIntVal(*Fn)(FunctionContext*, AnyVal*); SmallIntVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_INT: { typedef IntVal(*Fn)(FunctionContext*, AnyVal*); IntVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_BIGINT: { typedef BigIntVal(*Fn)(FunctionContext*, AnyVal*); BigIntVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_FLOAT: { typedef FloatVal(*Fn)(FunctionContext*, AnyVal*); FloatVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_DOUBLE: { typedef DoubleVal(*Fn)(FunctionContext*, AnyVal*); DoubleVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_STRING: case TYPE_VARCHAR: { typedef StringVal(*Fn)(FunctionContext*, AnyVal*); StringVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_DECIMAL: { typedef DecimalVal(*Fn)(FunctionContext*, AnyVal*); DecimalVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } case TYPE_TIMESTAMP: { typedef TimestampVal(*Fn)(FunctionContext*, AnyVal*); TimestampVal v = reinterpret_cast(fn)(agg_fn_ctx, staging_intermediate_val_); - SetDstSlot(&v, dst_slot_desc, dst); + SetDstSlot(agg_fn_ctx, &v, dst_slot_desc, dst); break; } default: diff --git a/be/src/exprs/agg-fn-evaluator.h b/be/src/exprs/agg-fn-evaluator.h index e80189f09..18e1c582e 100644 --- a/be/src/exprs/agg-fn-evaluator.h +++ b/be/src/exprs/agg-fn-evaluator.h @@ -231,8 +231,8 @@ class AggFnEvaluator { const SlotDescriptor* dst_slot_desc, Tuple* dst, void* fn); // Writes the result in src into dst pointed to by dst_slot_desc - void SetDstSlot(const impala_udf::AnyVal* src, const SlotDescriptor* dst_slot_desc, - Tuple* dst); + void SetDstSlot(FunctionContext* ctx, const impala_udf::AnyVal* src, + const SlotDescriptor* dst_slot_desc, Tuple* dst); }; inline void AggFnEvaluator::Add( diff --git a/be/src/exprs/aggregate-functions.cc b/be/src/exprs/aggregate-functions.cc index 386eb4293..5dfa3457e 100644 --- a/be/src/exprs/aggregate-functions.cc +++ b/be/src/exprs/aggregate-functions.cc @@ -1087,14 +1087,15 @@ double ComputeKnuthVariance(const KnuthVarianceState& state, bool pop) { void AggregateFunctions::KnuthVarInit(FunctionContext* ctx, StringVal* dst) { dst->is_null = false; - dst->len = sizeof(KnuthVarianceState); - dst->ptr = ctx->Allocate(dst->len); + DCHECK_EQ(dst->len, sizeof(KnuthVarianceState)); memset(dst->ptr, 0, dst->len); } template void AggregateFunctions::KnuthVarUpdate(FunctionContext* ctx, const T& src, StringVal* dst) { + DCHECK(!dst->is_null); + DCHECK_EQ(dst->len, sizeof(KnuthVarianceState)); if (src.is_null) return; KnuthVarianceState* state = reinterpret_cast(dst->ptr); double temp = 1 + state->count; @@ -1107,6 +1108,10 @@ void AggregateFunctions::KnuthVarUpdate(FunctionContext* ctx, const T& src, void AggregateFunctions::KnuthVarMerge(FunctionContext* ctx, const StringVal& src, StringVal* dst) { + DCHECK(!dst->is_null); + DCHECK_EQ(dst->len, sizeof(KnuthVarianceState)); + DCHECK(!src.is_null); + DCHECK_EQ(src.len, sizeof(KnuthVarianceState)); // Reference implementation: // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm KnuthVarianceState* src_state = reinterpret_cast(src.ptr); @@ -1128,34 +1133,31 @@ DoubleVal AggregateFunctions::KnuthVarFinalize( return DoubleVal(variance); } -StringVal AggregateFunctions::KnuthVarPopFinalize(FunctionContext* ctx, +DoubleVal AggregateFunctions::KnuthVarPopFinalize(FunctionContext* ctx, const StringVal& state_sv) { DCHECK(!state_sv.is_null); - KnuthVarianceState state = *reinterpret_cast(state_sv.ptr); - ctx->Free(state_sv.ptr); - if (state.count == 0) return StringVal::null(); - double variance = ComputeKnuthVariance(state, true); - return ToStringVal(ctx, variance); + DCHECK_EQ(state_sv.len, sizeof(KnuthVarianceState)); + KnuthVarianceState* state = reinterpret_cast(state_sv.ptr); + if (state->count == 0) return DoubleVal::null(); + return ComputeKnuthVariance(*state, true); } -StringVal AggregateFunctions::KnuthStddevFinalize(FunctionContext* ctx, +DoubleVal AggregateFunctions::KnuthStddevFinalize(FunctionContext* ctx, const StringVal& state_sv) { DCHECK(!state_sv.is_null); - KnuthVarianceState state = *reinterpret_cast(state_sv.ptr); - ctx->Free(state_sv.ptr); - if (state.count == 0) return StringVal::null(); - double variance = ComputeKnuthVariance(state, false); - return ToStringVal(ctx, sqrt(variance)); + DCHECK_EQ(state_sv.len, sizeof(KnuthVarianceState)); + KnuthVarianceState* state = reinterpret_cast(state_sv.ptr); + if (state->count == 0) return DoubleVal::null(); + return sqrt(ComputeKnuthVariance(*state, false)); } -StringVal AggregateFunctions::KnuthStddevPopFinalize(FunctionContext* ctx, +DoubleVal AggregateFunctions::KnuthStddevPopFinalize(FunctionContext* ctx, const StringVal& state_sv) { DCHECK(!state_sv.is_null); - KnuthVarianceState state = *reinterpret_cast(state_sv.ptr); - ctx->Free(state_sv.ptr); - if (state.count == 0) return StringVal::null(); - double variance = ComputeKnuthVariance(state, true); - return ToStringVal(ctx, sqrt(variance)); + DCHECK_EQ(state_sv.len, sizeof(KnuthVarianceState)); + KnuthVarianceState* state = reinterpret_cast(state_sv.ptr); + if (state->count == 0) return DoubleVal::null(); + return sqrt(ComputeKnuthVariance(*state, true)); } struct RankState { diff --git a/be/src/exprs/aggregate-functions.h b/be/src/exprs/aggregate-functions.h index 619909538..294e194e6 100644 --- a/be/src/exprs/aggregate-functions.h +++ b/be/src/exprs/aggregate-functions.h @@ -189,13 +189,13 @@ class AggregateFunctions { static DoubleVal KnuthVarFinalize(FunctionContext* context, const StringVal& val); // Calculates the biased variance, uses KnuthVar Init-Update-Merge functions - static StringVal KnuthVarPopFinalize(FunctionContext* context, const StringVal& val); + static DoubleVal KnuthVarPopFinalize(FunctionContext* context, const StringVal& val); // Calculates STDDEV, uses KnuthVar Init-Update-Merge functions - static StringVal KnuthStddevFinalize(FunctionContext* context, const StringVal& val); + static DoubleVal KnuthStddevFinalize(FunctionContext* context, const StringVal& val); // Calculates the biased STDDEV, uses KnuthVar Init-Update-Merge functions - static StringVal KnuthStddevPopFinalize(FunctionContext* context, const StringVal& val); + static DoubleVal KnuthStddevPopFinalize(FunctionContext* context, const StringVal& val); // ----------------------------- Analytic Functions --------------------------------- diff --git a/fe/src/main/java/com/cloudera/impala/catalog/BuiltinsDb.java b/fe/src/main/java/com/cloudera/impala/catalog/BuiltinsDb.java index 5f4a95f22..43c5ac118 100644 --- a/fe/src/main/java/com/cloudera/impala/catalog/BuiltinsDb.java +++ b/fe/src/main/java/com/cloudera/impala/catalog/BuiltinsDb.java @@ -574,51 +574,51 @@ public class BuiltinsDb extends Db { if (STDDEV_UPDATE_SYMBOL.containsKey(t)) { db.addBuiltin(AggregateFunction.createBuiltin(db, "stddev", - Lists.newArrayList(t), Type.STRING, Type.STRING, + Lists.newArrayList(t), Type.DOUBLE, ScalarType.createCharType(24), prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", prefix + STDDEV_UPDATE_SYMBOL.get(t), prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", - stringValSerializeOrFinalize, + null, prefix + "19KnuthStddevFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", true, false, false)); db.addBuiltin(AggregateFunction.createBuiltin(db, "stddev_samp", - Lists.newArrayList(t), Type.STRING, Type.STRING, + Lists.newArrayList(t), Type.DOUBLE, ScalarType.createCharType(24), prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", prefix + STDDEV_UPDATE_SYMBOL.get(t), prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", - stringValSerializeOrFinalize, + null, prefix + "19KnuthStddevFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", true, false, false)); db.addBuiltin(AggregateFunction.createBuiltin(db, "stddev_pop", - Lists.newArrayList(t), Type.STRING, Type.STRING, + Lists.newArrayList(t), Type.DOUBLE, ScalarType.createCharType(24), prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", prefix + STDDEV_UPDATE_SYMBOL.get(t), prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", - stringValSerializeOrFinalize, + null, prefix + "22KnuthStddevPopFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", true, false, false)); db.addBuiltin(AggregateFunction.createBuiltin(db, "variance", - Lists.newArrayList(t), Type.DOUBLE, Type.STRING, + Lists.newArrayList(t), Type.DOUBLE, ScalarType.createCharType(24), prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", prefix + STDDEV_UPDATE_SYMBOL.get(t), prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", - stringValSerializeOrFinalize, + null, prefix + "16KnuthVarFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", true, false, false)); db.addBuiltin(AggregateFunction.createBuiltin(db, "variance_samp", - Lists.newArrayList(t), Type.STRING, Type.STRING, + Lists.newArrayList(t), Type.DOUBLE, ScalarType.createCharType(24), prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", prefix + STDDEV_UPDATE_SYMBOL.get(t), prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", - stringValSerializeOrFinalize, + null, prefix + "16KnuthVarFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", true, false, false)); db.addBuiltin(AggregateFunction.createBuiltin(db, "variance_pop", - Lists.newArrayList(t), Type.STRING, Type.STRING, + Lists.newArrayList(t), Type.DOUBLE, ScalarType.createCharType(24), prefix + "12KnuthVarInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", prefix + STDDEV_UPDATE_SYMBOL.get(t), prefix + "13KnuthVarMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", - stringValSerializeOrFinalize, + null, prefix + "19KnuthVarPopFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", true, false, false)); } diff --git a/testdata/workloads/functional-query/queries/QueryTest/aggregation.test b/testdata/workloads/functional-query/queries/QueryTest/aggregation.test index c8aac8463..488139148 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/aggregation.test +++ b/testdata/workloads/functional-query/queries/QueryTest/aggregation.test @@ -17,9 +17,9 @@ SELECT variance(tinyint_col), stddev(smallint_col), variance_pop(int_col), stddev_pop(bigint_col) from alltypesagg WHERE id = -9999999 ---- RESULTS -NULL,'NULL','NULL','NULL' +NULL,NULL,NULL,NULL ---- TYPES -double, string, string, string +double, double, double, double ==== ---- QUERY # exactly 1 tuple processed (variance & stddev are 0) @@ -27,9 +27,9 @@ SELECT variance(tinyint_col), stddev(smallint_col), variance_pop(int_col), stddev_pop(bigint_col) from alltypesagg WHERE id = 1006 ---- RESULTS -0,'0','0','0' +0,0,0,0 ---- TYPES -double, string, string, string +double, double, double, double ==== ---- QUERY # Includes one row which is null @@ -46,28 +46,36 @@ SELECT variance_pop(tinyint_col), variance_pop(smallint_col), variance_pop(int_c variance_pop(bigint_col), variance_pop(float_col), variance_pop(double_col) from alltypesagg WHERE id >= 1000 AND id < 1006 ---- RESULTS -'2','2','2','200','2.42','204.02' +2,2,2,200,2.42,204.02 ---- TYPES -string, string, string, string, string, string +double, double, double, double, double, double ==== ---- QUERY -SELECT stddev(tinyint_col), stddev(smallint_col), stddev(int_col), stddev(bigint_col), -stddev(float_col), stddev(double_col) +SELECT round(stddev(tinyint_col), 5), + round(stddev(smallint_col), 5), + round(stddev(int_col), 5), + round(stddev(bigint_col), 5), + round(stddev(float_col), 5), + round(stddev(double_col), 5) from alltypesagg WHERE id >= 1000 AND id < 1006 ---- RESULTS -'1.58114','1.58114','1.58114','15.8114','1.73925','15.9695' +1.58114,1.58114,1.58114,15.81139,1.73925,15.96950 ---- TYPES -string, string, string, string, string, string +double, double, double, double, double, double ==== ---- QUERY # no grouping exprs, cols contain nulls except for bool cols -SELECT stddev_pop(tinyint_col), stddev_pop(smallint_col), stddev_pop(int_col), -stddev_pop(bigint_col), stddev_pop(float_col), stddev_pop(double_col) +SELECT round(stddev_pop(tinyint_col), 5), + round(stddev_pop(smallint_col), 5), + round(stddev_pop(int_col), 5), + round(stddev_pop(bigint_col), 5), + round(stddev_pop(float_col), 5), + round(stddev_pop(double_col), 5) from alltypesagg WHERE id >= 1000 AND id < 1006 ---- RESULTS -'1.41421','1.41421','1.41421','14.1421','1.55563','14.2836' +1.41421,1.41421,1.41421,14.14214,1.55563,14.28356 ---- TYPES -string, string, string, string, string, string +double, double, double, double, double, double ==== ---- QUERY # no grouping exprs, cols contain nulls except for bool cols