diff --git a/be/src/codegen/codegen-anyval.cc b/be/src/codegen/codegen-anyval.cc index bd2740984..a7788123d 100644 --- a/be/src/codegen/codegen-anyval.cc +++ b/be/src/codegen/codegen-anyval.cc @@ -67,7 +67,7 @@ Type* CodegenAnyVal::GetLoweredType(LlvmCodeGen* cg, const ColumnType& type) { } } -Type* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { +PointerType* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { return GetLoweredType(cg, type)->getPointerTo(); } @@ -116,7 +116,7 @@ Type* CodegenAnyVal::GetUnloweredType(LlvmCodeGen* cg, const ColumnType& type) { return result; } -Type* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { +PointerType* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { return GetUnloweredType(cg, type)->getPointerTo(); } diff --git a/be/src/codegen/codegen-anyval.h b/be/src/codegen/codegen-anyval.h index c07f3eb16..13494acda 100644 --- a/be/src/codegen/codegen-anyval.h +++ b/be/src/codegen/codegen-anyval.h @@ -95,7 +95,7 @@ class CodegenAnyVal { /// Returns the lowered AnyVal pointer type associated with 'type'. /// E.g.: TYPE_BOOLEAN => i16* - static llvm::Type* GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type); + static llvm::PointerType* GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type); /// Returns the unlowered AnyVal type associated with 'type'. /// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal" @@ -103,7 +103,7 @@ class CodegenAnyVal { /// Returns the unlowered AnyVal pointer type associated with 'type'. /// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal"* - static llvm::Type* GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type); + static llvm::PointerType* GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type); /// Return the constant type-lowered value corresponding to a null *Val. /// E.g.: passing TYPE_DOUBLE (corresponding to the lowered DoubleVal { i8, double }) diff --git a/be/src/codegen/llvm-codegen.cc b/be/src/codegen/llvm-codegen.cc index 3d730d5ad..6cf43f6da 100644 --- a/be/src/codegen/llvm-codegen.cc +++ b/be/src/codegen/llvm-codegen.cc @@ -70,6 +70,7 @@ #include "util/hdfs-util.h" #include "util/path-builder.h" #include "util/runtime-profile-counters.h" +#include "util/symbols-util.h" #include "util/test-info.h" #include "common/names.h" @@ -766,8 +767,116 @@ Function* LlvmCodeGen::FnPrototype::GeneratePrototype( return fn; } -int LlvmCodeGen::ReplaceCallSites(Function* caller, Function* new_fn, - const string& target_name) { +Status LlvmCodeGen::LoadFunction(const TFunction& fn, const std::string& symbol, + const ColumnType* return_type, const std::vector& arg_types, + int num_fixed_args, bool has_varargs, Function** llvm_fn, + LibCacheEntry** cache_entry) { + DCHECK_GE(arg_types.size(), num_fixed_args); + DCHECK(has_varargs || arg_types.size() == num_fixed_args); + DCHECK(!has_varargs || arg_types.size() > num_fixed_args); + // from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot be JIT'd. + // TimestampFunctions::AddSub() contains a try/catch which doesn't work in JIT'd + // code. Always use the interpreted version of these functions. + // TODO: fix these built-in functions so we don't need 'broken_builtin' below. + bool broken_builtin = fn.name.function_name == "from_utc_timestamp" + || fn.name.function_name == "to_utc_timestamp" + || symbol.find("AddSub") != string::npos; + if (fn.binary_type == TFunctionBinaryType::NATIVE + || (fn.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) { + // In this path, we are calling a precompiled native function, either a UDF + // in a .so or a builtin using the UDF interface. + void* fn_ptr; + Status status = LibCache::instance()->GetSoFunctionPtr( + fn.hdfs_location, symbol, &fn_ptr, cache_entry); + if (!status.ok() && fn.binary_type == TFunctionBinaryType::BUILTIN) { + // Builtins symbols should exist unless there is a version mismatch. + status.AddDetail( + ErrorMsg(TErrorCode::MISSING_BUILTIN, fn.name.function_name, symbol).msg()); + } + RETURN_IF_ERROR(status); + DCHECK(fn_ptr != NULL); + + // Per the x64 ABI, DecimalVals are returned via a DecimalVal* output argument. + // So, the return type is void. + bool is_decimal = return_type != NULL && return_type->type == TYPE_DECIMAL; + Type* llvm_return_type = return_type == NULL || is_decimal ? + void_type() : + CodegenAnyVal::GetLoweredType(this, *return_type); + + // Convert UDF function pointer to Function*. Start by creating a function + // prototype for it. + FnPrototype prototype(this, symbol, llvm_return_type); + + if (is_decimal) { + // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument + Type* output_type = CodegenAnyVal::GetUnloweredPtrType(this, *return_type); + prototype.AddArgument("output", output_type); + } + + // The "FunctionContext*" argument. + prototype.AddArgument("ctx", GetPtrType("class.impala_udf::FunctionContext")); + + // The "fixed" arguments for the UDF function, followed by the variable arguments, + // if any. + for (int i = 0; i < num_fixed_args; ++i) { + Type* arg_type = CodegenAnyVal::GetUnloweredPtrType(this, arg_types[i]); + prototype.AddArgument(Substitute("fixed_arg_$0", i), arg_type); + } + + if (has_varargs) { + prototype.AddArgument("num_var_arg", GetType(TYPE_INT)); + // Get the vararg type from the first vararg. + prototype.AddArgument( + "var_arg", CodegenAnyVal::GetUnloweredPtrType(this, arg_types[num_fixed_args])); + } + + // Create a Function* with the generated type. This is only a function + // declaration, not a definition, since we do not create any basic blocks or + // instructions in it. + *llvm_fn = prototype.GeneratePrototype(NULL, NULL, false); + + // Associate the dynamically loaded function pointer with the Function* we defined. + // This tells LLVM where the compiled function definition is located in memory. + execution_engine_->addGlobalMapping(*llvm_fn, fn_ptr); + } else if (fn.binary_type == TFunctionBinaryType::BUILTIN) { + // In this path, we're running a builtin with the UDF interface. The IR is + // in the llvm module. Builtin functions may use Expr::GetConstant(). Clone the + // function so that we can replace constants in the copied function. + *llvm_fn = GetFunction(symbol, true); + if (*llvm_fn == NULL) { + // Builtins symbols should exist unless there is a version mismatch. + return Status(Substitute("Builtin '$0' with symbol '$1' does not exist. Verify " + "that all your impalads are the same version.", + fn.name.function_name, symbol)); + } + // Rename the function to something more readable than the mangled name. + string demangled_name = SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str()); + (*llvm_fn)->setName(demangled_name); + } else { + // We're running an IR UDF. + DCHECK_EQ(fn.binary_type, TFunctionBinaryType::IR); + + string local_path; + RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath( + fn.hdfs_location, LibCache::TYPE_IR, &local_path)); + // Link the UDF module into this query's main module so the UDF's functions are + // available in the main module. + RETURN_IF_ERROR(LinkModule(local_path)); + + *llvm_fn = GetFunction(symbol, true); + if (*llvm_fn == NULL) { + return Status(Substitute("Unable to load function '$0' from LLVM module '$1'", + symbol, fn.hdfs_location)); + } + // Rename the function to something more readable than the mangled name. + string demangled_name = SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str()); + (*llvm_fn)->setName(demangled_name); + } + return Status::OK(); +} + +int LlvmCodeGen::ReplaceCallSites( + Function* caller, Function* new_fn, const string& target_name) { DCHECK(!is_compiled_); DCHECK(caller->getParent() == module_); DCHECK(caller != NULL); diff --git a/be/src/codegen/llvm-codegen.h b/be/src/codegen/llvm-codegen.h index 961566484..d51faabe7 100644 --- a/be/src/codegen/llvm-codegen.h +++ b/be/src/codegen/llvm-codegen.h @@ -294,6 +294,23 @@ class LlvmCodeGen { /// functions. Status FinalizeModule(); + /// Loads a native or IR function 'fn' with symbol 'symbol' from the builtins or + /// an external library and puts the result in *llvm_fn. *llvm_fn can be safely + /// modified in place, because it is either newly generated or cloned. The caller must + /// call FinalizeFunction() on 'llvm_fn' once it is done modifying it. The function has + /// return type 'return_type' (void if 'return_type' is NULL) and input argument types + /// 'arg_types'. The first 'num_fixed_args' arguments are fixed arguments, and the + /// remaining arguments are varargs. 'has_varargs' indicates whether the function + /// accepts varargs. If 'has_varargs' is true, there must be at least one vararg. If + /// the function is loaded from a library, 'cache_entry' is updated to point to the + /// library containing the function. If 'cache_entry' is set to a non-NULL value by + /// this function, the caller must call LibCache::DecrementUseCount() on it when done + /// using the function. + Status LoadFunction(const TFunction& fn, const std::string& symbol, + const ColumnType* return_type, const std::vector& arg_types, + int num_fixed_args, bool has_varargs, llvm::Function** llvm_fn, + LibCacheEntry** cache_entry); + /// Replaces all instructions in 'caller' that call 'target_name' with a call /// instruction to 'new_fn'. Returns the number of call sites updated. /// @@ -485,10 +502,6 @@ class LlvmCodeGen { llvm::Value* CodegenArrayAt( LlvmBuilder*, llvm::Value* array, int idx, const char* name = ""); - /// Loads a module at 'file' and links it to the module associated with - /// this LlvmCodeGen object. The module must be on the local filesystem. - Status LinkModule(const std::string& file); - /// If there are more than this number of expr trees (or functions that evaluate /// expressions), avoid inlining avoid inlining for the exprs exceeding this threshold. static const int CODEGEN_INLINE_EXPRS_THRESHOLD = 100; @@ -538,6 +551,10 @@ class LlvmCodeGen { Status LoadModuleFromMemory(std::unique_ptr module_ir_buf, std::string module_name, std::unique_ptr* module); + /// Loads a module at 'file' and links it to the module associated with + /// this LlvmCodeGen object. The module must be on the local filesystem. + Status LinkModule(const std::string& file); + /// Strip global constructors and destructors from an LLVM module. We never run them /// anyway (they must be explicitly invoked) so it is dead code. static void StripGlobalCtorsDtors(llvm::Module* module); diff --git a/be/src/exec/partitioned-aggregation-node.cc b/be/src/exec/partitioned-aggregation-node.cc index 6cc36bece..7c75d01f4 100644 --- a/be/src/exec/partitioned-aggregation-node.cc +++ b/be/src/exec/partitioned-aggregation-node.cc @@ -1721,22 +1721,8 @@ Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen, const vector& input_vals, const CodegenAnyVal& dst, CodegenAnyVal* updated_dst_val) { DCHECK_EQ(evaluator->input_expr_ctxs().size(), input_vals.size()); - const string& symbol = - evaluator->is_merge() ? evaluator->merge_symbol() : evaluator->update_symbol(); - const ColumnType& dst_type = evaluator->intermediate_type(); - - // TODO: to support actual UDAs, not just builtin functions using the UDA interface, - // we need to load the function at this point. - Function* uda_fn = codegen->GetFunction(symbol, true); - DCHECK(uda_fn != NULL); - - vector arg_types; - for (int i = 0; i < evaluator->input_expr_ctxs().size(); ++i) { - arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc( - evaluator->input_expr_ctxs()[i]->root()->type())); - } - Expr::InlineConstants( - AnyValUtil::ColumnTypeToTypeDesc(dst_type), arg_types, codegen, uda_fn); + Function* uda_fn; + RETURN_IF_ERROR(evaluator->GetUpdateOrMergeFunction(codegen, &uda_fn)); // Set up arguments for call to UDA, which are the FunctionContext*, followed by // pointers to all input values, followed by a pointer to the destination value. @@ -1753,6 +1739,7 @@ Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen, // Create pointer to dst to pass to uda_fn. We must use the unlowered type for the // same reason as above. Value* dst_lowered_ptr = dst.GetLoweredPtr("dst_lowered_ptr"); + const ColumnType& dst_type = evaluator->intermediate_type(); Type* dst_unlowered_ptr_type = CodegenAnyVal::GetUnloweredPtrType(codegen, dst_type); Value* dst_unlowered_ptr = builder->CreateBitCast( dst_lowered_ptr, dst_unlowered_ptr_type, "dst_unlowered_ptr"); @@ -1825,13 +1812,6 @@ Status PartitionedAggregationNode::CodegenUpdateTuple( "intermediate tuple desc"); } - for (AggFnEvaluator* evaluator : aggregate_evaluators_) { - // Don't codegen things that aren't builtins (for now) - if (!evaluator->is_builtin()) { - return Status("PartitionedAggregationNode::CodegenUpdateTuple(): UDA codegen NYI"); - } - } - // Get the types to match the UpdateTuple signature Type* agg_node_type = codegen->GetType(PartitionedAggregationNode::LLVM_CLASS_NAME); Type* fn_ctx_type = codegen->GetType(FunctionContextImpl::LLVM_FUNCTIONCONTEXT_NAME); diff --git a/be/src/exprs/agg-fn-evaluator.cc b/be/src/exprs/agg-fn-evaluator.cc index 4c6a99382..b62313076 100644 --- a/be/src/exprs/agg-fn-evaluator.cc +++ b/be/src/exprs/agg-fn-evaluator.cc @@ -23,8 +23,9 @@ #include "common/logging.h" #include "exec/aggregation-node.h" #include "exprs/aggregate-functions.h" -#include "exprs/expr-context.h" #include "exprs/anyval-util.h" +#include "exprs/expr-context.h" +#include "exprs/scalar-fn-call.h" #include "runtime/lib-cache.h" #include "runtime/raw-value.h" #include "runtime/runtime-state.h" @@ -94,6 +95,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, bool is_analytic_fn) is_analytic_fn_(is_analytic_fn), intermediate_slot_desc_(NULL), output_slot_desc_(NULL), + arg_type_descs_(AnyValUtil::ColumnTypesToTypeDescs( + ColumnType::FromThrift(desc.agg_expr.arg_types))), cache_entry_(NULL), init_fn_(NULL), update_fn_(NULL), @@ -198,28 +201,15 @@ Status AggFnEvaluator::Prepare(RuntimeState* state, const RowDescriptor& desc, &cache_entry_)); } if (!fn_.aggregate_fn.remove_fn_symbol.empty()) { - RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr( - fn_.hdfs_location, fn_.aggregate_fn.remove_fn_symbol, &remove_fn_, - &cache_entry_)); + RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, + fn_.aggregate_fn.remove_fn_symbol, &remove_fn_, &cache_entry_)); } if (!fn_.aggregate_fn.finalize_fn_symbol.empty()) { - RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr( - fn_.hdfs_location, fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_, - &cache_entry_)); + RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, + fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_, &cache_entry_)); } - - vector arg_types; - for (int i = 0; i < input_expr_ctxs_.size(); ++i) { - arg_types.push_back( - AnyValUtil::ColumnTypeToTypeDesc(input_expr_ctxs_[i]->root()->type())); - } - - FunctionContext::TypeDesc intermediate_type = - AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type()); - FunctionContext::TypeDesc output_type = - AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type()); - *agg_fn_ctx = FunctionContextImpl::CreateContext( - state, agg_fn_pool, intermediate_type, output_type, arg_types); + *agg_fn_ctx = FunctionContextImpl::CreateContext(state, agg_fn_pool, + GetIntermediateTypeDesc(), GetOutputTypeDesc(), arg_type_descs_); return Status::OK(); } @@ -521,6 +511,40 @@ void AggFnEvaluator::SerializeOrFinalize(FunctionContext* agg_fn_ctx, Tuple* src } } +/// Gets the update or merge function for this UDA. +Status AggFnEvaluator::GetUpdateOrMergeFunction(LlvmCodeGen* codegen, Function** uda_fn) { + const string& symbol = + is_merge_ ? fn_.aggregate_fn.merge_fn_symbol : fn_.aggregate_fn.update_fn_symbol; + vector fn_arg_types; + for (ExprContext* input_expr_ctx : input_expr_ctxs_) { + fn_arg_types.push_back(input_expr_ctx->root()->type()); + } + // The intermediate value is passed as the last argument. + fn_arg_types.push_back(intermediate_type()); + RETURN_IF_ERROR(codegen->LoadFunction(fn_, symbol, NULL, fn_arg_types, + fn_arg_types.size(), false, uda_fn, &cache_entry_)); + + // Inline constants into the function body (if there is an IR body). + if (!(*uda_fn)->isDeclaration()) { + // TODO: IMPALA-4785: we should also replace references to GetIntermediateType() + // with constants. + Expr::InlineConstants(GetOutputTypeDesc(), arg_type_descs_, codegen, *uda_fn); + *uda_fn = codegen->FinalizeFunction(*uda_fn); + if (*uda_fn == NULL) { + return Status(TErrorCode::UDF_VERIFY_FAILED, symbol, fn_.hdfs_location); + } + } + return Status::OK(); +} + +FunctionContext::TypeDesc AggFnEvaluator::GetIntermediateTypeDesc() const { + return AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type()); +} + +FunctionContext::TypeDesc AggFnEvaluator::GetOutputTypeDesc() const { + return AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type()); +} + string AggFnEvaluator::DebugString(const vector& exprs) { stringstream out; out << "["; diff --git a/be/src/exprs/agg-fn-evaluator.h b/be/src/exprs/agg-fn-evaluator.h index b3ecda015..712bf407b 100644 --- a/be/src/exprs/agg-fn-evaluator.h +++ b/be/src/exprs/agg-fn-evaluator.h @@ -119,8 +119,6 @@ class AggFnEvaluator { bool SupportsRemove() const { return remove_fn_ != NULL; } bool SupportsSerialize() const { return serialize_fn_ != NULL; } const std::string& fn_name() const { return fn_.name.function_name; } - const std::string& update_symbol() const { return fn_.aggregate_fn.update_fn_symbol; } - const std::string& merge_symbol() const { return fn_.aggregate_fn.merge_fn_symbol; } const SlotDescriptor* output_slot_desc() const { return output_slot_desc_; } static std::string DebugString(const std::vector& exprs); @@ -168,14 +166,8 @@ class AggFnEvaluator { static void Finalize(const std::vector& evaluators, const std::vector& fn_ctxs, Tuple* src, Tuple* dst); - /// TODO: implement codegen path. These functions would return IR functions with - /// the same signature as the interpreted ones above. - /// Function* GetIrInitFn(); - /// Function* GetIrAddFn(); - /// Function* GetIrRemoveFn(); - /// Function* GetIrSerializeFn(); - /// Function* GetIrGetValueFn(); - /// Function* GetIrFinalizeFn(); + /// Gets the codegened update or merge function for this aggregate function. + Status GetUpdateOrMergeFunction(LlvmCodeGen* codegen, llvm::Function** uda_fn); private: const TFunction fn_; @@ -195,6 +187,9 @@ class AggFnEvaluator { /// expression (e.g. count(*)). std::vector input_expr_ctxs_; + /// The types of the arguments to the aggregate function. + const std::vector arg_type_descs_; + /// The enum for some of the builtins that still require special cased logic. AggregationOp agg_op_; @@ -221,6 +216,12 @@ class AggFnEvaluator { /// Use Create() instead. AggFnEvaluator(const TExprNode& desc, bool is_analytic_fn); + /// Return the intermediate type of the aggregate function. + FunctionContext::TypeDesc GetIntermediateTypeDesc() const; + + /// Return the output type of the aggregate function. + FunctionContext::TypeDesc GetOutputTypeDesc() const; + /// TODO: these functions below are not extensible and we need to use codegen to /// generate the calls into the UDA functions (like for UDFs). /// Remove these functions when this is supported. diff --git a/be/src/exprs/anyval-util.cc b/be/src/exprs/anyval-util.cc index 132d6e43a..c49cdb3a5 100644 --- a/be/src/exprs/anyval-util.cc +++ b/be/src/exprs/anyval-util.cc @@ -92,17 +92,33 @@ FunctionContext::TypeDesc AnyValUtil::ColumnTypeToTypeDesc(const ColumnType& typ return out; } +vector AnyValUtil::ColumnTypesToTypeDescs( + const vector& types) { + vector type_descs; + for (const ColumnType& type : types) type_descs.push_back(ColumnTypeToTypeDesc(type)); + return type_descs; +} + ColumnType AnyValUtil::TypeDescToColumnType(const FunctionContext::TypeDesc& type) { switch (type.type) { - case FunctionContext::TYPE_BOOLEAN: return ColumnType(TYPE_BOOLEAN); - case FunctionContext::TYPE_TINYINT: return ColumnType(TYPE_TINYINT); - case FunctionContext::TYPE_SMALLINT: return ColumnType(TYPE_SMALLINT); - case FunctionContext::TYPE_INT: return ColumnType(TYPE_INT); - case FunctionContext::TYPE_BIGINT: return ColumnType(TYPE_BIGINT); - case FunctionContext::TYPE_FLOAT: return ColumnType(TYPE_FLOAT); - case FunctionContext::TYPE_DOUBLE: return ColumnType(TYPE_DOUBLE); - case FunctionContext::TYPE_TIMESTAMP: return ColumnType(TYPE_TIMESTAMP); - case FunctionContext::TYPE_STRING: return ColumnType(TYPE_STRING); + case FunctionContext::TYPE_BOOLEAN: + return ColumnType(TYPE_BOOLEAN); + case FunctionContext::TYPE_TINYINT: + return ColumnType(TYPE_TINYINT); + case FunctionContext::TYPE_SMALLINT: + return ColumnType(TYPE_SMALLINT); + case FunctionContext::TYPE_INT: + return ColumnType(TYPE_INT); + case FunctionContext::TYPE_BIGINT: + return ColumnType(TYPE_BIGINT); + case FunctionContext::TYPE_FLOAT: + return ColumnType(TYPE_FLOAT); + case FunctionContext::TYPE_DOUBLE: + return ColumnType(TYPE_DOUBLE); + case FunctionContext::TYPE_TIMESTAMP: + return ColumnType(TYPE_TIMESTAMP); + case FunctionContext::TYPE_STRING: + return ColumnType(TYPE_STRING); case FunctionContext::TYPE_DECIMAL: return ColumnType::CreateDecimalType(type.precision, type.scale); case FunctionContext::TYPE_FIXED_BUFFER: diff --git a/be/src/exprs/anyval-util.h b/be/src/exprs/anyval-util.h index e5473c7b7..429322a6b 100644 --- a/be/src/exprs/anyval-util.h +++ b/be/src/exprs/anyval-util.h @@ -227,6 +227,8 @@ class AnyValUtil { } static FunctionContext::TypeDesc ColumnTypeToTypeDesc(const ColumnType& type); + static std::vector ColumnTypesToTypeDescs( + const std::vector& types); // Note: constructing a ColumnType is expensive and should be avoided in query execution // paths (i.e. non-setup paths). static ColumnType TypeDescToColumnType(const FunctionContext::TypeDesc& type); diff --git a/be/src/exprs/expr.cc b/be/src/exprs/expr.cc index 2ffddf6d2..f119a9d53 100644 --- a/be/src/exprs/expr.cc +++ b/be/src/exprs/expr.cc @@ -637,10 +637,10 @@ int Expr::InlineConstants(LlvmCodeGen* codegen, Function* fn) { } int Expr::InlineConstants(const FunctionContext::TypeDesc& return_type, - const std::vector& arg_types, LlvmCodeGen* codegen, - Function* fn) { + const std::vector& arg_types, LlvmCodeGen* codegen, + Function* fn) { int replaced = 0; - for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end; ) { + for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end;) { // Increment iter now so we don't mess it up modifying the instruction below Instruction* instr = &*(iter++); @@ -666,7 +666,7 @@ int Expr::InlineConstants(const FunctionContext::TypeDesc& return_type, int i_val = static_cast(i_arg->getSExtValue()); // All supported constants are currently integers. call_instr->replaceAllUsesWith(ConstantInt::get(codegen->GetType(TYPE_INT), - GetConstantInt(return_type, arg_types, c_val, i_val))); + GetConstantInt(return_type, arg_types, c_val, i_val))); call_instr->eraseFromParent(); ++replaced; } diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h index b77c9a29b..13ba31267 100644 --- a/be/src/exprs/expr.h +++ b/be/src/exprs/expr.h @@ -263,9 +263,11 @@ class Expr { // Any additions to this enum must be reflected in both GetConstant*() and // GetIrConstant(). enum ExprConstant { + // RETURN_TYPE_*: properties of FunctionContext::GetReturnType(). RETURN_TYPE_SIZE, // int RETURN_TYPE_PRECISION, // int RETURN_TYPE_SCALE, // int + // ARG_TYPE_* with parameter i: properties of FunctionContext::GetArgType(i). ARG_TYPE_SIZE, // int[] ARG_TYPE_PRECISION, // int[] ARG_TYPE_SCALE, // int[] @@ -289,7 +291,8 @@ class Expr { // constants to be replaced must be inlined into the function that InlineConstants() // is run on (e.g. by annotating them with IR_ALWAYS_INLINE). // - // TODO: implement a loop unroller (or use LLVM's) so we can use GetConstantInt() in loops + // TODO: implement a loop unroller (or use LLVM's) so we can use GetConstantInt() in + // loops static int GetConstantInt(const FunctionContext& ctx, ExprConstant c, int i = -1); /// Finds all calls to Expr::GetConstantInt() in 'fn' and replaces them with the @@ -298,8 +301,8 @@ class Expr { /// 'arg_types' are the argument types of the UDF or UDAF, i.e. the values of /// FunctionContext::GetArgType(). static int InlineConstants(const FunctionContext::TypeDesc& return_type, - const std::vector& arg_types, - LlvmCodeGen* codegen, llvm::Function* fn); + const std::vector& arg_types, LlvmCodeGen* codegen, + llvm::Function* fn); static const char* LLVM_CLASS_NAME; diff --git a/be/src/exprs/scalar-fn-call.cc b/be/src/exprs/scalar-fn-call.cc index c7d4e7ed5..06830b750 100644 --- a/be/src/exprs/scalar-fn-call.cc +++ b/be/src/exprs/scalar-fn-call.cc @@ -19,12 +19,12 @@ #include #include -#include #include +#include #include -#include #include +#include #include #include "codegen/codegen-anyval.h" @@ -37,8 +37,6 @@ #include "runtime/types.h" #include "udf/udf-internal.h" #include "util/debug-util.h" -#include "util/dynamic-util.h" -#include "util/symbols-util.h" #include "common/names.h" @@ -311,20 +309,27 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) { } } - if (fn_.binary_type == TFunctionBinaryType::IR) { - string local_path; - RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath( - fn_.hdfs_location, LibCache::TYPE_IR, &local_path)); - // Link the UDF module into this query's main module (essentially copy the UDF - // module into the main module) so the UDF's functions are available in the main - // module. - RETURN_IF_ERROR(codegen->LinkModule(local_path)); - // Load the Prepare() and Close() functions from the LLVM module. - RETURN_IF_ERROR(LoadPrepareAndCloseFn(codegen)); + vector arg_types; + for (const Expr* child : children_) arg_types.push_back(child->type()); + Function* udf; + RETURN_IF_ERROR(codegen->LoadFunction(fn_, fn_.scalar_fn.symbol, &type_, arg_types, + NumFixedArgs(), vararg_start_idx_ != -1, &udf, &cache_entry_)); + // Inline constants into the function if it has an IR body. + if (!udf->isDeclaration()) { + InlineConstants(AnyValUtil::ColumnTypeToTypeDesc(type_), + AnyValUtil::ColumnTypesToTypeDescs(arg_types), codegen, udf); + udf = codegen->FinalizeFunction(udf); + if (udf == NULL) { + return Status( + TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location); + } } - Function* udf; - RETURN_IF_ERROR(GetUdf(codegen, &udf)); + if (fn_.binary_type == TFunctionBinaryType::IR) { + // LoadFunction() should have linked the IR module into 'codegen'. Now load the + // Prepare() and Close() functions from 'codegen'. + RETURN_IF_ERROR(LoadPrepareAndCloseFn(codegen)); + } // Create wrapper that computes args and calls UDF stringstream fn_name; @@ -407,8 +412,7 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) { // Add the number of varargs udf_args.push_back(codegen->GetIntConstant(TYPE_INT, NumVarArgs())); // Add all the accumulated vararg inputs as one input argument. - PointerType* vararg_type = - codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, VarArgsType())); + PointerType* vararg_type = CodegenAnyVal::GetUnloweredPtrType(codegen, VarArgsType()); udf_args.push_back(builder.CreateBitCast(varargs_buffer, vararg_type, "varargs")); } @@ -428,119 +432,11 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) { return Status::OK(); } -Status ScalarFnCall::GetUdf(LlvmCodeGen* codegen, Function** udf) { - // from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot be JIT'd. - // TimestampFunctions::AddSub() contains a try/catch which doesn't work in JIT'd - // code. Always use the interpreted version of these functions. - // TODO: fix these built-in functions so we don't need 'broken_builtin' below. - bool broken_builtin = fn_.name.function_name == "from_utc_timestamp" || - fn_.name.function_name == "to_utc_timestamp" || - fn_.scalar_fn.symbol.find("AddSub") != string::npos; - if (fn_.binary_type == TFunctionBinaryType::NATIVE || - (fn_.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) { - // In this path, we are code that has been statically compiled to assembly. - // This can either be a UDF implemented in a .so or a builtin using the UDF - // interface with the code in impalad. - void* fn_ptr; - Status status = LibCache::instance()->GetSoFunctionPtr( - fn_.hdfs_location, fn_.scalar_fn.symbol, &fn_ptr, &cache_entry_); - if (!status.ok() && fn_.binary_type == TFunctionBinaryType::BUILTIN) { - // Builtins symbols should exist unless there is a version mismatch. - status.AddDetail(ErrorMsg(TErrorCode::MISSING_BUILTIN, - fn_.name.function_name, fn_.scalar_fn.symbol).msg()); - } - RETURN_IF_ERROR(status); - DCHECK(fn_ptr != NULL); - - // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument. - // So, the return type is void. - bool is_decimal = type().type == TYPE_DECIMAL; - Type* return_type = is_decimal ? codegen->void_type() : - CodegenAnyVal::GetLoweredType(codegen, type()); - - // Convert UDF function pointer to Function*. Start by creating a function - // prototype for it. - LlvmCodeGen::FnPrototype prototype(codegen, fn_.scalar_fn.symbol, return_type); - - if (is_decimal) { - // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument - Type* output_type = - codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, type())); - prototype.AddArgument("output", output_type); - } - - // The "FunctionContext*" argument. - prototype.AddArgument("ctx", - codegen->GetPtrType("class.impala_udf::FunctionContext")); - - // The "fixed" arguments for the UDF function. - for (int i = 0; i < NumFixedArgs(); ++i) { - stringstream arg_name; - arg_name << "fixed_arg_" << i; - Type* arg_type = codegen->GetPtrType( - CodegenAnyVal::GetUnloweredType(codegen, children_[i]->type())); - prototype.AddArgument(arg_name.str(), arg_type); - } - // The varargs for the UDF function if there is any. - if (NumVarArgs() > 0) { - Type* vararg_type = CodegenAnyVal::GetUnloweredPtrType( - codegen, children_[vararg_start_idx_]->type()); - prototype.AddArgument("num_var_arg", codegen->GetType(TYPE_INT)); - prototype.AddArgument("var_arg", vararg_type); - } - - // Create a Function* with the generated type. This is only a function - // declaration, not a definition, since we do not create any basic blocks or - // instructions in it. - *udf = prototype.GeneratePrototype(NULL, NULL, false); - - // Associate the dynamically loaded function pointer with the Function* we defined. - // This tells LLVM where the compiled function definition is located in memory. - codegen->execution_engine()->addGlobalMapping(*udf, fn_ptr); - } else if (fn_.binary_type == TFunctionBinaryType::BUILTIN) { - // In this path, we're running a builtin with the UDF interface. The IR is - // in the llvm module. - *udf = codegen->GetFunction(fn_.scalar_fn.symbol, false); - if (*udf == NULL) { - // Builtins symbols should exist unless there is a version mismatch. - stringstream ss; - ss << "Builtin '" << fn_.name.function_name << "' with symbol '" - << fn_.scalar_fn.symbol << "' does not exist. " - << "Verify that all your impalads are the same version."; - return Status(ss.str()); - } - // Builtin functions may use Expr::GetConstant(). Clone the function in case we need - // to use it again, and rename it to something more manageable than the mangled name. - string demangled_name = SymbolsUtil::DemangleNoArgs((*udf)->getName().str()); - *udf = codegen->CloneFunction(*udf); - (*udf)->setName(demangled_name); - InlineConstants(codegen, *udf); - *udf = codegen->FinalizeFunction(*udf); - DCHECK(*udf != NULL); - } else { - // We're running an IR UDF. - DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR); - *udf = codegen->GetFunction(fn_.scalar_fn.symbol, false); - if (*udf == NULL) { - stringstream ss; - ss << "Unable to locate function " << fn_.scalar_fn.symbol << " from LLVM module " - << fn_.hdfs_location; - return Status(ss.str()); - } - *udf = codegen->FinalizeFunction(*udf); - if (*udf == NULL) { - return Status( - TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location); - } - } - return Status::OK(); -} - Status ScalarFnCall::GetFunction(LlvmCodeGen* codegen, const string& symbol, void** fn) { - if (fn_.binary_type == TFunctionBinaryType::NATIVE || - fn_.binary_type == TFunctionBinaryType::BUILTIN) { - return LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, symbol, fn, - &cache_entry_); + if (fn_.binary_type == TFunctionBinaryType::NATIVE + || fn_.binary_type == TFunctionBinaryType::BUILTIN) { + return LibCache::instance()->GetSoFunctionPtr( + fn_.hdfs_location, symbol, fn, &cache_entry_); } else { DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR); DCHECK(codegen != NULL); diff --git a/be/src/exprs/scalar-fn-call.h b/be/src/exprs/scalar-fn-call.h index c8bc8c89a..ffb9a2fd5 100644 --- a/be/src/exprs/scalar-fn-call.h +++ b/be/src/exprs/scalar-fn-call.h @@ -48,7 +48,7 @@ class TExprNode; /// - Test cancellation /// - Type descs in UDA test harness /// - Allow more functions to be NULL in UDA test harness -class ScalarFnCall: public Expr { +class ScalarFnCall : public Expr { public: virtual std::string DebugString() const; @@ -117,9 +117,6 @@ class ScalarFnCall: public Expr { return children_.back()->type(); } - /// Loads the native or IR function from HDFS and puts the result in *udf. - Status GetUdf(LlvmCodeGen* codegen, llvm::Function** udf); - /// Loads the native or IR function 'symbol' from HDFS and puts the result in *fn. /// If the function is loaded from an IR module, it cannot be called until the module /// has been JIT'd (i.e. after GetCodegendComputeFn() has been called). diff --git a/be/src/exprs/timestamp-functions.cc b/be/src/exprs/timestamp-functions.cc index a5196961d..a63438c91 100644 --- a/be/src/exprs/timestamp-functions.cc +++ b/be/src/exprs/timestamp-functions.cc @@ -76,7 +76,7 @@ void ThrowIfDateOutOfRange(const boost::gregorian::date& date) { // This function uses inline asm functions, which we believe to be from the boost library. // Inline asm is not currently supported by JIT, so this function should always be run in -// the interpreted mode. This is handled in ScalarFnCall::GetUdf(). +// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction(). TimestampVal TimestampFunctions::FromUtc(FunctionContext* context, const TimestampVal& ts_val, const StringVal& tz_string_val) { if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null(); @@ -114,7 +114,7 @@ TimestampVal TimestampFunctions::FromUtc(FunctionContext* context, // This function uses inline asm functions, which we believe to be from the boost library. // Inline asm is not currently supported by JIT, so this function should always be run in -// the interpreted mode. This is handled in ScalarFnCall::GetUdf(). +// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction(). TimestampVal TimestampFunctions::ToUtc(FunctionContext* context, const TimestampVal& ts_val, const StringVal& tz_string_val) { if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null(); diff --git a/be/src/runtime/types.cc b/be/src/runtime/types.cc index 3a04ca3b5..f580628b1 100644 --- a/be/src/runtime/types.cc +++ b/be/src/runtime/types.cc @@ -310,6 +310,12 @@ string ColumnType::DebugString() const { } } +vector ColumnType::FromThrift(const vector& ttypes) { + vector types; + for (const TColumnType& ttype : ttypes) types.push_back(FromThrift(ttype)); + return types; +} + ostream& operator<<(ostream& os, const ColumnType& type) { os << type.DebugString(); return os; diff --git a/be/src/runtime/types.h b/be/src/runtime/types.h index f265705ae..f2db3bd8f 100644 --- a/be/src/runtime/types.h +++ b/be/src/runtime/types.h @@ -147,6 +147,8 @@ struct ColumnType { return result; } + static std::vector FromThrift(const std::vector& ttypes); + bool operator==(const ColumnType& o) const { if (type != o.type) return false; if (children != o.children) return false; diff --git a/be/src/testutil/test-udas.cc b/be/src/testutil/test-udas.cc index 009750081..549f2f0cf 100644 --- a/be/src/testutil/test-udas.cc +++ b/be/src/testutil/test-udas.cc @@ -17,6 +17,9 @@ #include "testutil/test-udas.h" +#include + +// Don't include Impala internal headers - real UDAs won't include them. #include using namespace impala_udf; @@ -48,19 +51,95 @@ void Agg(FunctionContext*, const StringVal&, const DoubleVal&, StringVal*) {} void AggInit(FunctionContext*, StringVal*){} void AggMerge(FunctionContext*, const StringVal&, StringVal*) {} StringVal AggSerialize(FunctionContext*, const StringVal& v) { return v;} -StringVal AggFinalize(FunctionContext*, const StringVal& v) { return v;} +StringVal AggFinalize(FunctionContext*, const StringVal& v) { + return v; +} - -// Defines AggIntermediate(int) returns BIGINT intermediate CHAR(10) -// TODO: StringVal should be replaced with BufferVal in Impala 2.0 -void AggIntermediate(FunctionContext*, const IntVal&, StringVal*) {} -void AggIntermediateUpdate(FunctionContext*, const IntVal&, StringVal*) {} -void AggIntermediateInit(FunctionContext*, StringVal*) {} -void AggIntermediateMerge(FunctionContext*, const StringVal&, StringVal*) {} -BigIntVal AggIntermediateFinalize(FunctionContext*, const StringVal&) { +// Defines AggIntermediate(int) returns BIGINT intermediate STRING +void AggIntermediate(FunctionContext* context, const IntVal&, StringVal*) {} +void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); +} +void AggIntermediateInit(FunctionContext* context, StringVal*) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); +} +void AggIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); +} +BigIntVal AggIntermediateFinalize(FunctionContext* context, const StringVal&) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); return BigIntVal::null(); } +// Defines AggDecimalIntermediate(DECIMAL(1,2), INT) returns DECIMAL(5,6) +// intermediate DECIMAL(3,4) +// Useful to test that type parameters are plumbed through. +void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, const IntVal&, DecimalVal*) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); +} +void AggDecimalIntermediateInit(FunctionContext* context, DecimalVal*) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); +} +void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, DecimalVal*) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); +} +DecimalVal AggDecimalIntermediateFinalize(FunctionContext* context, const DecimalVal&) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); + return DecimalVal::null(); +} + // Defines MemTest(bigint) return bigint // "Allocates" the specified number of bytes in the update function and frees them in the // serialize function. Useful for testing mem limits. @@ -99,22 +178,57 @@ BigIntVal MemTestFinalize(FunctionContext* context, const BigIntVal& total) { // Defines aggregate function for testing different intermediate/output types that // computes the truncated bigint sum of many floats. void TruncSumInit(FunctionContext* context, DoubleVal* total) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); *total = DoubleVal(0); } void TruncSumUpdate(FunctionContext* context, const DoubleVal& val, DoubleVal* total) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); total->val += val.val; } void TruncSumMerge(FunctionContext* context, const DoubleVal& src, DoubleVal* dst) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); dst->val += src.val; } const DoubleVal TruncSumSerialize(FunctionContext* context, const DoubleVal& total) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); return total; } BigIntVal TruncSumFinalize(FunctionContext* context, const DoubleVal& total) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); return BigIntVal(static_cast(total.val)); } diff --git a/be/src/testutil/test-udas.h b/be/src/testutil/test-udas.h index 8cd38ecc3..57b0f066a 100644 --- a/be/src/testutil/test-udas.h +++ b/be/src/testutil/test-udas.h @@ -18,6 +18,7 @@ #ifndef IMPALA_UDF_TEST_UDAS_H #define IMPALA_UDF_TEST_UDAS_H +// Don't include Impala internal headers - real UDAs won't include them. #include "udf/udf.h" using namespace impala_udf; diff --git a/be/src/udf/udf-internal.h b/be/src/udf/udf-internal.h index bf3032ed8..96d0fc775 100644 --- a/be/src/udf/udf-internal.h +++ b/be/src/udf/udf-internal.h @@ -47,6 +47,12 @@ class RuntimeState; /// This class actually implements the interface of FunctionContext. This is split to /// hide the details from the external header. /// Note: The actual user code does not include this file. +/// +/// Exprs (e.g. UDFs and UDAs) require a FunctionContext to store state related to +/// evaluation of the expression. Each FunctionContext is associated with a backend Expr +/// or AggFnEvaluator, which is derived from a TExprNode generated by the Impala frontend. +/// FunctionContexts are allocated and managed by ExprContext. Exprs shouldn't try to +/// create FunctionContext themselves. class FunctionContextImpl { public: /// Create a FunctionContext for a UDF. Caller is responsible for deleting it. diff --git a/be/src/udf/udf-ir.cc b/be/src/udf/udf-ir.cc index 24773f0e2..c12133c7b 100644 --- a/be/src/udf/udf-ir.cc +++ b/be/src/udf/udf-ir.cc @@ -34,6 +34,10 @@ int FunctionContext::GetNumArgs() const { return impl_->arg_types_.size(); } +const FunctionContext::TypeDesc& FunctionContext::GetIntermediateType() const { + return impl_->intermediate_type_; +} + const FunctionContext::TypeDesc& FunctionContext::GetReturnType() const { return impl_->return_type_; } diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index 461482c17..4fdca2da1 100644 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -189,21 +189,25 @@ class FunctionContext { const TypeDesc& GetIntermediateType() const; /// Returns the number of arguments to this function (not including the FunctionContext* - /// argument). + /// argument or the output of a UDA). + /// For UDAs, returns the number of logical arguments of the aggregate function, not + /// the number of arguments of the C++ function being executed. int GetNumArgs() const; /// Returns the type information for the arg_idx-th argument (0-indexed, not including /// the FunctionContext* argument). Returns NULL if arg_idx is invalid. + /// For UDAs, returns the logical argument types of the aggregate function, not the + /// argument types of the C++ function being executed. const TypeDesc* GetArgType(int arg_idx) const; - /// Returns true if the arg_idx-th input argument (0 indexed, not including the - /// FunctionContext* argument) is a constant (e.g. 5, "string", 1 + 1). + /// Returns true if the arg_idx-th input argument (indexed in the same way as + /// GetArgType()) is a constant (e.g. 5, "string", 1 + 1). bool IsArgConstant(int arg_idx) const; - /// Returns a pointer to the value of the arg_idx-th input argument (0 indexed, not - /// including the FunctionContext* argument). Returns NULL if the argument is not - /// constant. This function can be used to obtain user-specified constants in a UDF's - /// Init() or Close() functions. + /// Returns a pointer to the value of the arg_idx-th input argument (indexed in the + /// same way as GetArgType()). Returns NULL if the argument is not constant. This + /// function can be used to obtain user-specified constants in a UDF's Init() or + /// Close() functions. AnyVal* GetConstantArg(int arg_idx) const; /// TODO: Do we need to add arbitrary key/value metadata. This would be plumbed diff --git a/common/thrift/Exprs.thrift b/common/thrift/Exprs.thrift index 33b859a54..fc0f4ee92 100644 --- a/common/thrift/Exprs.thrift +++ b/common/thrift/Exprs.thrift @@ -112,9 +112,14 @@ struct TStringLiteral { 1: required string value; } +// Additional information for aggregate functions. struct TAggregateExpr { // Indicates whether this expr is the merge() of an aggregation. 1: required bool is_merge_agg + + // The types of the input arguments to the aggregate function. May differ from the + // input expr types if this is the merge() of an aggregation. + 2: required list arg_types; } // This is essentially a union over the subclasses of Expr. diff --git a/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java b/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java index ec88aaeb6..5dbde32de 100644 --- a/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java +++ b/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java @@ -673,7 +673,8 @@ public class AggregateInfo extends AggregateInfoBase { * materialized slots of the output tuple corresponds to the number of materialized * aggregate functions plus the number of grouping exprs. Also checks that the return * types of the aggregate and grouping exprs correspond to the slots in the output - * tuple. + * tuple and that the input types stored in the merge aggregation are consistent + * with the input exprs. */ public void checkConsistency() { ArrayList slots = outputTupleDesc_.getSlots(); @@ -707,6 +708,13 @@ public class AggregateInfo extends AggregateInfoBase { slotType.toString())); ++slotIdx; } + if (mergeAggInfo_ != null) { + // Check that the argument types in mergeAggInfo_ are consistent with input exprs. + for (int i = 0; i < aggregateExprs_.size(); ++i) { + FunctionCallExpr mergeAggExpr = mergeAggInfo_.aggregateExprs_.get(i); + mergeAggExpr.validateMergeAggFn(aggregateExprs_.get(i)); + } + } } /** diff --git a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java index 15af25f9a..4d7dca835 100644 --- a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java +++ b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java @@ -30,6 +30,7 @@ import org.apache.impala.catalog.Type; import org.apache.impala.common.AnalysisException; import org.apache.impala.common.TreeNode; import org.apache.impala.thrift.TAggregateExpr; +import org.apache.impala.thrift.TColumnType; import org.apache.impala.thrift.TExprNode; import org.apache.impala.thrift.TExprNodeType; import org.apache.impala.thrift.TFunctionBinaryType; @@ -45,10 +46,12 @@ public class FunctionCallExpr extends Expr { private boolean isAnalyticFnCall_ = false; private boolean isInternalFnCall_ = false; - // Indicates whether this is a merge aggregation function that should use the merge - // instead of the update symbol. This flag also affects the behavior of - // resetAnalysisState() which is used during expr substitution. - private final boolean isMergeAggFn_; + // Non-null iff this is an aggregation function that executes the Merge() step. This + // is an analyzed clone of the FunctionCallExpr that executes the Update() function + // feeding into this Merge(). This is stored so that we can access the types of the + // original input argument exprs. Note that the nullness affects the behaviour of + // resetAnalysisState(), which is used during expr substitution. + private final FunctionCallExpr mergeAggInputFn_; // Printed in toSqlImpl(), if set. Used for merge agg fns. private String label_; @@ -62,15 +65,16 @@ public class FunctionCallExpr extends Expr { } public FunctionCallExpr(FunctionName fnName, FunctionParams params) { - this(fnName, params, false); + this(fnName, params, null); } - private FunctionCallExpr( - FunctionName fnName, FunctionParams params, boolean isMergeAggFn) { + private FunctionCallExpr(FunctionName fnName, FunctionParams params, + FunctionCallExpr mergeAggInputFn) { super(); fnName_ = fnName; params_ = params; - isMergeAggFn_ = isMergeAggFn; + mergeAggInputFn_ = + mergeAggInputFn == null ? null : (FunctionCallExpr)mergeAggInputFn.clone(); if (params.exprs() != null) children_ = Lists.newArrayList(params_.exprs()); } @@ -99,12 +103,12 @@ public class FunctionCallExpr extends Expr { Preconditions.checkState(agg.isAnalyzed()); Preconditions.checkState(agg.isAggregateFunction()); FunctionCallExpr result = new FunctionCallExpr( - agg.fnName_, new FunctionParams(false, params), true); + agg.fnName_, new FunctionParams(false, params), agg); // Inherit the function object from 'agg'. result.fn_ = agg.fn_; result.type_ = agg.type_; // Set an explicit label based on the input agg. - if (agg.isMergeAggFn_) { + if (agg.isMergeAggFn()) { result.label_ = agg.label_; } else { // fn(input) becomes fn:merge(input). @@ -123,7 +127,8 @@ public class FunctionCallExpr extends Expr { fnName_ = other.fnName_; isAnalyticFnCall_ = other.isAnalyticFnCall_; isInternalFnCall_ = other.isInternalFnCall_; - isMergeAggFn_ = other.isMergeAggFn_; + mergeAggInputFn_ = + other.mergeAggInputFn_ == null ? null : (FunctionCallExpr)other.mergeAggInputFn_.clone(); // Clone the params in a way that keeps the children_ and the params.exprs() // in sync. The children have already been cloned in the super c'tor. if (other.params_.isStar()) { @@ -135,7 +140,7 @@ public class FunctionCallExpr extends Expr { label_ = other.label_; } - public boolean isMergeAggFn() { return isMergeAggFn_; } + public boolean isMergeAggFn() { return mergeAggInputFn_ != null; } @Override public void resetAnalysisState() { @@ -144,7 +149,7 @@ public class FunctionCallExpr extends Expr { // intermediate agg type is not the same as the output type. Preserve the original // fn_ such that analyze() hits the special-case code for merge agg fns that // handles this case. - if (!isMergeAggFn_) fn_ = null; + if (!isMergeAggFn()) fn_ = null; } @Override @@ -160,7 +165,7 @@ public class FunctionCallExpr extends Expr { public String toSqlImpl() { if (label_ != null) return label_; // Merge agg fns should have an explicit label. - Preconditions.checkState(!isMergeAggFn_); + Preconditions.checkState(!isMergeAggFn()); StringBuilder sb = new StringBuilder(); sb.append(fnName_).append("("); if (params_.isStar()) sb.append("*"); @@ -226,7 +231,12 @@ public class FunctionCallExpr extends Expr { protected void toThrift(TExprNode msg) { if (isAggregateFunction() || isAnalyticFnCall_) { msg.node_type = TExprNodeType.AGGREGATE_EXPR; - if (!isAnalyticFnCall_) msg.setAgg_expr(new TAggregateExpr(isMergeAggFn_)); + List aggFnArgTypes = Lists.newArrayList(); + FunctionCallExpr inputAggFn = isMergeAggFn() ? mergeAggInputFn_ : this; + for (Expr child: inputAggFn.children_) { + aggFnArgTypes.add(child.getType().toThrift()); + } + msg.setAgg_expr(new TAggregateExpr(isMergeAggFn(), aggFnArgTypes)); } else { msg.node_type = TExprNodeType.FUNCTION_CALL; } @@ -383,7 +393,7 @@ public class FunctionCallExpr extends Expr { protected void analyzeImpl(Analyzer analyzer) throws AnalysisException { fnName_.analyze(analyzer); - if (isMergeAggFn_) { + if (isMergeAggFn()) { // This is the function call expr after splitting up to a merge aggregation. // The function has already been analyzed so just do the minimal sanity // check here. @@ -524,6 +534,25 @@ public class FunctionCallExpr extends Expr { } } + /** + * Validate that the internal state, specifically types, is consistent between the + * the Update() and Merge() aggregate functions. + */ + void validateMergeAggFn(FunctionCallExpr inputAggFn) { + Preconditions.checkState(isMergeAggFn()); + List copiedInputExprs = mergeAggInputFn_.getChildren(); + List inputExprs = inputAggFn.getChildren(); + Preconditions.checkState(copiedInputExprs.size() == inputExprs.size()); + for (int i = 0; i < inputExprs.size(); ++i) { + Type copiedInputType = copiedInputExprs.get(i).getType(); + Type inputType = inputExprs.get(i).getType(); + Preconditions.checkState(copiedInputType.equals(inputType), + String.format("Copied expr %s arg type %s differs from input expr type %s " + + "in original expr %s", toSql(), copiedInputType.toSql(), + inputType.toSql(), inputAggFn.toSql())); + } + } + @Override public Expr clone() { return new FunctionCallExpr(this); } } diff --git a/testdata/workloads/functional-query/queries/QueryTest/uda.test b/testdata/workloads/functional-query/queries/QueryTest/uda.test index 05e24e71b..3a9bbbec0 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/uda.test +++ b/testdata/workloads/functional-query/queries/QueryTest/uda.test @@ -69,3 +69,22 @@ from functional.alltypesagg ---- TYPES bigint,bigint ==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. +# This relies on asserts in the UDA funciton +select agg_intermediate(int_col), count(*) +from functional.alltypesagg +---- RESULTS +NULL,11000 +---- TYPES +bigint,bigint +==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. +# This relies on asserts in the UDA funciton +select agg_decimal_intermediate(cast(d1 as decimal(2,1)), 2), count(*) +from functional.decimal_tbl +---- RESULTS +NULL,5 +---- TYPES +decimal,bigint diff --git a/tests/query_test/test_udfs.py b/tests/query_test/test_udfs.py index 746cf8828..56ce233a5 100644 --- a/tests/query_test/test_udfs.py +++ b/tests/query_test/test_udfs.py @@ -93,6 +93,16 @@ update_fn='ToggleNullUpdate' merge_fn='ToggleNullMerge'; create aggregate function {database}.count_nulls(bigint) returns bigint location '{location}' update_fn='CountNullsUpdate' merge_fn='CountNullsMerge'; + +create aggregate function {database}.agg_intermediate(int) +returns bigint intermediate string location '{location}' +init_fn='AggIntermediateInit' update_fn='AggIntermediateUpdate' +merge_fn='AggIntermediateMerge' finalize_fn='AggIntermediateFinalize'; + +create aggregate function {database}.agg_decimal_intermediate(decimal(2,1), int) +returns decimal(6,5) intermediate decimal(4,3) location '{location}' +init_fn='AggDecimalIntermediateInit' update_fn='AggDecimalIntermediateUpdate' +merge_fn='AggDecimalIntermediateMerge' finalize_fn='AggDecimalIntermediateFinalize'; """ # Create test UDF functions in {database} from library {location}