mirror of
https://github.com/apache/impala.git
synced 2026-01-06 15:01:43 -05:00
IMPALA-1430,IMPALA-4878,IMPALA-4879: codegen native UDAs
This uses the existing infrastructure for codegening builtin UDAs and for codegening calls to UDFs. GetUdf() is refactored to support both UDFs and UDAs. IR UDAs are still not allowed by the frontend. It's unclear if we want to enable them going forward because of the difficulties in testing and supporting IR UDFs/UDAs. This also fixes some bugs with the Get*Type() methods of FunctionContext. GetArgType() and related methods now always return the logical input types of the aggregate function. Getting the tests to pass required fixing IMPALA-4878 because they called GetIntermediateType(). Testing: test_udfs.py tests UDAs with codegen enabled and disabled. Added some asserts to test UDAs to check that the correct types are passed in via the FunctionContext. Change-Id: Id1708eaa96eb76fb9bec5eeabf209f81c88eec2f Reviewed-on: http://gerrit.cloudera.org:8080/5161 Reviewed-by: Dan Hecht <dhecht@cloudera.com> Tested-by: Impala Public Jenkins
This commit is contained in:
committed by
Impala Public Jenkins
parent
1335af3684
commit
d2d3f4c1a6
@@ -67,7 +67,7 @@ Type* CodegenAnyVal::GetLoweredType(LlvmCodeGen* cg, const ColumnType& type) {
|
||||
}
|
||||
}
|
||||
|
||||
Type* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type) {
|
||||
PointerType* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type) {
|
||||
return GetLoweredType(cg, type)->getPointerTo();
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ Type* CodegenAnyVal::GetUnloweredType(LlvmCodeGen* cg, const ColumnType& type) {
|
||||
return result;
|
||||
}
|
||||
|
||||
Type* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type) {
|
||||
PointerType* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type) {
|
||||
return GetUnloweredType(cg, type)->getPointerTo();
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ class CodegenAnyVal {
|
||||
|
||||
/// Returns the lowered AnyVal pointer type associated with 'type'.
|
||||
/// E.g.: TYPE_BOOLEAN => i16*
|
||||
static llvm::Type* GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type);
|
||||
static llvm::PointerType* GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type);
|
||||
|
||||
/// Returns the unlowered AnyVal type associated with 'type'.
|
||||
/// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal"
|
||||
@@ -103,7 +103,7 @@ class CodegenAnyVal {
|
||||
|
||||
/// Returns the unlowered AnyVal pointer type associated with 'type'.
|
||||
/// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal"*
|
||||
static llvm::Type* GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type);
|
||||
static llvm::PointerType* GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type);
|
||||
|
||||
/// Return the constant type-lowered value corresponding to a null *Val.
|
||||
/// E.g.: passing TYPE_DOUBLE (corresponding to the lowered DoubleVal { i8, double })
|
||||
|
||||
@@ -70,6 +70,7 @@
|
||||
#include "util/hdfs-util.h"
|
||||
#include "util/path-builder.h"
|
||||
#include "util/runtime-profile-counters.h"
|
||||
#include "util/symbols-util.h"
|
||||
#include "util/test-info.h"
|
||||
|
||||
#include "common/names.h"
|
||||
@@ -766,8 +767,116 @@ Function* LlvmCodeGen::FnPrototype::GeneratePrototype(
|
||||
return fn;
|
||||
}
|
||||
|
||||
int LlvmCodeGen::ReplaceCallSites(Function* caller, Function* new_fn,
|
||||
const string& target_name) {
|
||||
Status LlvmCodeGen::LoadFunction(const TFunction& fn, const std::string& symbol,
|
||||
const ColumnType* return_type, const std::vector<ColumnType>& arg_types,
|
||||
int num_fixed_args, bool has_varargs, Function** llvm_fn,
|
||||
LibCacheEntry** cache_entry) {
|
||||
DCHECK_GE(arg_types.size(), num_fixed_args);
|
||||
DCHECK(has_varargs || arg_types.size() == num_fixed_args);
|
||||
DCHECK(!has_varargs || arg_types.size() > num_fixed_args);
|
||||
// from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot be JIT'd.
|
||||
// TimestampFunctions::AddSub() contains a try/catch which doesn't work in JIT'd
|
||||
// code. Always use the interpreted version of these functions.
|
||||
// TODO: fix these built-in functions so we don't need 'broken_builtin' below.
|
||||
bool broken_builtin = fn.name.function_name == "from_utc_timestamp"
|
||||
|| fn.name.function_name == "to_utc_timestamp"
|
||||
|| symbol.find("AddSub") != string::npos;
|
||||
if (fn.binary_type == TFunctionBinaryType::NATIVE
|
||||
|| (fn.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) {
|
||||
// In this path, we are calling a precompiled native function, either a UDF
|
||||
// in a .so or a builtin using the UDF interface.
|
||||
void* fn_ptr;
|
||||
Status status = LibCache::instance()->GetSoFunctionPtr(
|
||||
fn.hdfs_location, symbol, &fn_ptr, cache_entry);
|
||||
if (!status.ok() && fn.binary_type == TFunctionBinaryType::BUILTIN) {
|
||||
// Builtins symbols should exist unless there is a version mismatch.
|
||||
status.AddDetail(
|
||||
ErrorMsg(TErrorCode::MISSING_BUILTIN, fn.name.function_name, symbol).msg());
|
||||
}
|
||||
RETURN_IF_ERROR(status);
|
||||
DCHECK(fn_ptr != NULL);
|
||||
|
||||
// Per the x64 ABI, DecimalVals are returned via a DecimalVal* output argument.
|
||||
// So, the return type is void.
|
||||
bool is_decimal = return_type != NULL && return_type->type == TYPE_DECIMAL;
|
||||
Type* llvm_return_type = return_type == NULL || is_decimal ?
|
||||
void_type() :
|
||||
CodegenAnyVal::GetLoweredType(this, *return_type);
|
||||
|
||||
// Convert UDF function pointer to Function*. Start by creating a function
|
||||
// prototype for it.
|
||||
FnPrototype prototype(this, symbol, llvm_return_type);
|
||||
|
||||
if (is_decimal) {
|
||||
// Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument
|
||||
Type* output_type = CodegenAnyVal::GetUnloweredPtrType(this, *return_type);
|
||||
prototype.AddArgument("output", output_type);
|
||||
}
|
||||
|
||||
// The "FunctionContext*" argument.
|
||||
prototype.AddArgument("ctx", GetPtrType("class.impala_udf::FunctionContext"));
|
||||
|
||||
// The "fixed" arguments for the UDF function, followed by the variable arguments,
|
||||
// if any.
|
||||
for (int i = 0; i < num_fixed_args; ++i) {
|
||||
Type* arg_type = CodegenAnyVal::GetUnloweredPtrType(this, arg_types[i]);
|
||||
prototype.AddArgument(Substitute("fixed_arg_$0", i), arg_type);
|
||||
}
|
||||
|
||||
if (has_varargs) {
|
||||
prototype.AddArgument("num_var_arg", GetType(TYPE_INT));
|
||||
// Get the vararg type from the first vararg.
|
||||
prototype.AddArgument(
|
||||
"var_arg", CodegenAnyVal::GetUnloweredPtrType(this, arg_types[num_fixed_args]));
|
||||
}
|
||||
|
||||
// Create a Function* with the generated type. This is only a function
|
||||
// declaration, not a definition, since we do not create any basic blocks or
|
||||
// instructions in it.
|
||||
*llvm_fn = prototype.GeneratePrototype(NULL, NULL, false);
|
||||
|
||||
// Associate the dynamically loaded function pointer with the Function* we defined.
|
||||
// This tells LLVM where the compiled function definition is located in memory.
|
||||
execution_engine_->addGlobalMapping(*llvm_fn, fn_ptr);
|
||||
} else if (fn.binary_type == TFunctionBinaryType::BUILTIN) {
|
||||
// In this path, we're running a builtin with the UDF interface. The IR is
|
||||
// in the llvm module. Builtin functions may use Expr::GetConstant(). Clone the
|
||||
// function so that we can replace constants in the copied function.
|
||||
*llvm_fn = GetFunction(symbol, true);
|
||||
if (*llvm_fn == NULL) {
|
||||
// Builtins symbols should exist unless there is a version mismatch.
|
||||
return Status(Substitute("Builtin '$0' with symbol '$1' does not exist. Verify "
|
||||
"that all your impalads are the same version.",
|
||||
fn.name.function_name, symbol));
|
||||
}
|
||||
// Rename the function to something more readable than the mangled name.
|
||||
string demangled_name = SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str());
|
||||
(*llvm_fn)->setName(demangled_name);
|
||||
} else {
|
||||
// We're running an IR UDF.
|
||||
DCHECK_EQ(fn.binary_type, TFunctionBinaryType::IR);
|
||||
|
||||
string local_path;
|
||||
RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath(
|
||||
fn.hdfs_location, LibCache::TYPE_IR, &local_path));
|
||||
// Link the UDF module into this query's main module so the UDF's functions are
|
||||
// available in the main module.
|
||||
RETURN_IF_ERROR(LinkModule(local_path));
|
||||
|
||||
*llvm_fn = GetFunction(symbol, true);
|
||||
if (*llvm_fn == NULL) {
|
||||
return Status(Substitute("Unable to load function '$0' from LLVM module '$1'",
|
||||
symbol, fn.hdfs_location));
|
||||
}
|
||||
// Rename the function to something more readable than the mangled name.
|
||||
string demangled_name = SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str());
|
||||
(*llvm_fn)->setName(demangled_name);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int LlvmCodeGen::ReplaceCallSites(
|
||||
Function* caller, Function* new_fn, const string& target_name) {
|
||||
DCHECK(!is_compiled_);
|
||||
DCHECK(caller->getParent() == module_);
|
||||
DCHECK(caller != NULL);
|
||||
|
||||
@@ -294,6 +294,23 @@ class LlvmCodeGen {
|
||||
/// functions.
|
||||
Status FinalizeModule();
|
||||
|
||||
/// Loads a native or IR function 'fn' with symbol 'symbol' from the builtins or
|
||||
/// an external library and puts the result in *llvm_fn. *llvm_fn can be safely
|
||||
/// modified in place, because it is either newly generated or cloned. The caller must
|
||||
/// call FinalizeFunction() on 'llvm_fn' once it is done modifying it. The function has
|
||||
/// return type 'return_type' (void if 'return_type' is NULL) and input argument types
|
||||
/// 'arg_types'. The first 'num_fixed_args' arguments are fixed arguments, and the
|
||||
/// remaining arguments are varargs. 'has_varargs' indicates whether the function
|
||||
/// accepts varargs. If 'has_varargs' is true, there must be at least one vararg. If
|
||||
/// the function is loaded from a library, 'cache_entry' is updated to point to the
|
||||
/// library containing the function. If 'cache_entry' is set to a non-NULL value by
|
||||
/// this function, the caller must call LibCache::DecrementUseCount() on it when done
|
||||
/// using the function.
|
||||
Status LoadFunction(const TFunction& fn, const std::string& symbol,
|
||||
const ColumnType* return_type, const std::vector<ColumnType>& arg_types,
|
||||
int num_fixed_args, bool has_varargs, llvm::Function** llvm_fn,
|
||||
LibCacheEntry** cache_entry);
|
||||
|
||||
/// Replaces all instructions in 'caller' that call 'target_name' with a call
|
||||
/// instruction to 'new_fn'. Returns the number of call sites updated.
|
||||
///
|
||||
@@ -485,10 +502,6 @@ class LlvmCodeGen {
|
||||
llvm::Value* CodegenArrayAt(
|
||||
LlvmBuilder*, llvm::Value* array, int idx, const char* name = "");
|
||||
|
||||
/// Loads a module at 'file' and links it to the module associated with
|
||||
/// this LlvmCodeGen object. The module must be on the local filesystem.
|
||||
Status LinkModule(const std::string& file);
|
||||
|
||||
/// If there are more than this number of expr trees (or functions that evaluate
|
||||
/// expressions), avoid inlining avoid inlining for the exprs exceeding this threshold.
|
||||
static const int CODEGEN_INLINE_EXPRS_THRESHOLD = 100;
|
||||
@@ -538,6 +551,10 @@ class LlvmCodeGen {
|
||||
Status LoadModuleFromMemory(std::unique_ptr<llvm::MemoryBuffer> module_ir_buf,
|
||||
std::string module_name, std::unique_ptr<llvm::Module>* module);
|
||||
|
||||
/// Loads a module at 'file' and links it to the module associated with
|
||||
/// this LlvmCodeGen object. The module must be on the local filesystem.
|
||||
Status LinkModule(const std::string& file);
|
||||
|
||||
/// Strip global constructors and destructors from an LLVM module. We never run them
|
||||
/// anyway (they must be explicitly invoked) so it is dead code.
|
||||
static void StripGlobalCtorsDtors(llvm::Module* module);
|
||||
|
||||
@@ -1721,22 +1721,8 @@ Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen,
|
||||
const vector<CodegenAnyVal>& input_vals, const CodegenAnyVal& dst,
|
||||
CodegenAnyVal* updated_dst_val) {
|
||||
DCHECK_EQ(evaluator->input_expr_ctxs().size(), input_vals.size());
|
||||
const string& symbol =
|
||||
evaluator->is_merge() ? evaluator->merge_symbol() : evaluator->update_symbol();
|
||||
const ColumnType& dst_type = evaluator->intermediate_type();
|
||||
|
||||
// TODO: to support actual UDAs, not just builtin functions using the UDA interface,
|
||||
// we need to load the function at this point.
|
||||
Function* uda_fn = codegen->GetFunction(symbol, true);
|
||||
DCHECK(uda_fn != NULL);
|
||||
|
||||
vector<FunctionContext::TypeDesc> arg_types;
|
||||
for (int i = 0; i < evaluator->input_expr_ctxs().size(); ++i) {
|
||||
arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc(
|
||||
evaluator->input_expr_ctxs()[i]->root()->type()));
|
||||
}
|
||||
Expr::InlineConstants(
|
||||
AnyValUtil::ColumnTypeToTypeDesc(dst_type), arg_types, codegen, uda_fn);
|
||||
Function* uda_fn;
|
||||
RETURN_IF_ERROR(evaluator->GetUpdateOrMergeFunction(codegen, &uda_fn));
|
||||
|
||||
// Set up arguments for call to UDA, which are the FunctionContext*, followed by
|
||||
// pointers to all input values, followed by a pointer to the destination value.
|
||||
@@ -1753,6 +1739,7 @@ Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen,
|
||||
// Create pointer to dst to pass to uda_fn. We must use the unlowered type for the
|
||||
// same reason as above.
|
||||
Value* dst_lowered_ptr = dst.GetLoweredPtr("dst_lowered_ptr");
|
||||
const ColumnType& dst_type = evaluator->intermediate_type();
|
||||
Type* dst_unlowered_ptr_type = CodegenAnyVal::GetUnloweredPtrType(codegen, dst_type);
|
||||
Value* dst_unlowered_ptr = builder->CreateBitCast(
|
||||
dst_lowered_ptr, dst_unlowered_ptr_type, "dst_unlowered_ptr");
|
||||
@@ -1825,13 +1812,6 @@ Status PartitionedAggregationNode::CodegenUpdateTuple(
|
||||
"intermediate tuple desc");
|
||||
}
|
||||
|
||||
for (AggFnEvaluator* evaluator : aggregate_evaluators_) {
|
||||
// Don't codegen things that aren't builtins (for now)
|
||||
if (!evaluator->is_builtin()) {
|
||||
return Status("PartitionedAggregationNode::CodegenUpdateTuple(): UDA codegen NYI");
|
||||
}
|
||||
}
|
||||
|
||||
// Get the types to match the UpdateTuple signature
|
||||
Type* agg_node_type = codegen->GetType(PartitionedAggregationNode::LLVM_CLASS_NAME);
|
||||
Type* fn_ctx_type = codegen->GetType(FunctionContextImpl::LLVM_FUNCTIONCONTEXT_NAME);
|
||||
|
||||
@@ -23,8 +23,9 @@
|
||||
#include "common/logging.h"
|
||||
#include "exec/aggregation-node.h"
|
||||
#include "exprs/aggregate-functions.h"
|
||||
#include "exprs/expr-context.h"
|
||||
#include "exprs/anyval-util.h"
|
||||
#include "exprs/expr-context.h"
|
||||
#include "exprs/scalar-fn-call.h"
|
||||
#include "runtime/lib-cache.h"
|
||||
#include "runtime/raw-value.h"
|
||||
#include "runtime/runtime-state.h"
|
||||
@@ -94,6 +95,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, bool is_analytic_fn)
|
||||
is_analytic_fn_(is_analytic_fn),
|
||||
intermediate_slot_desc_(NULL),
|
||||
output_slot_desc_(NULL),
|
||||
arg_type_descs_(AnyValUtil::ColumnTypesToTypeDescs(
|
||||
ColumnType::FromThrift(desc.agg_expr.arg_types))),
|
||||
cache_entry_(NULL),
|
||||
init_fn_(NULL),
|
||||
update_fn_(NULL),
|
||||
@@ -198,28 +201,15 @@ Status AggFnEvaluator::Prepare(RuntimeState* state, const RowDescriptor& desc,
|
||||
&cache_entry_));
|
||||
}
|
||||
if (!fn_.aggregate_fn.remove_fn_symbol.empty()) {
|
||||
RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(
|
||||
fn_.hdfs_location, fn_.aggregate_fn.remove_fn_symbol, &remove_fn_,
|
||||
&cache_entry_));
|
||||
RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
|
||||
fn_.aggregate_fn.remove_fn_symbol, &remove_fn_, &cache_entry_));
|
||||
}
|
||||
if (!fn_.aggregate_fn.finalize_fn_symbol.empty()) {
|
||||
RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(
|
||||
fn_.hdfs_location, fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_,
|
||||
&cache_entry_));
|
||||
RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
|
||||
fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_, &cache_entry_));
|
||||
}
|
||||
|
||||
vector<FunctionContext::TypeDesc> arg_types;
|
||||
for (int i = 0; i < input_expr_ctxs_.size(); ++i) {
|
||||
arg_types.push_back(
|
||||
AnyValUtil::ColumnTypeToTypeDesc(input_expr_ctxs_[i]->root()->type()));
|
||||
}
|
||||
|
||||
FunctionContext::TypeDesc intermediate_type =
|
||||
AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type());
|
||||
FunctionContext::TypeDesc output_type =
|
||||
AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type());
|
||||
*agg_fn_ctx = FunctionContextImpl::CreateContext(
|
||||
state, agg_fn_pool, intermediate_type, output_type, arg_types);
|
||||
*agg_fn_ctx = FunctionContextImpl::CreateContext(state, agg_fn_pool,
|
||||
GetIntermediateTypeDesc(), GetOutputTypeDesc(), arg_type_descs_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@@ -521,6 +511,40 @@ void AggFnEvaluator::SerializeOrFinalize(FunctionContext* agg_fn_ctx, Tuple* src
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the update or merge function for this UDA.
|
||||
Status AggFnEvaluator::GetUpdateOrMergeFunction(LlvmCodeGen* codegen, Function** uda_fn) {
|
||||
const string& symbol =
|
||||
is_merge_ ? fn_.aggregate_fn.merge_fn_symbol : fn_.aggregate_fn.update_fn_symbol;
|
||||
vector<ColumnType> fn_arg_types;
|
||||
for (ExprContext* input_expr_ctx : input_expr_ctxs_) {
|
||||
fn_arg_types.push_back(input_expr_ctx->root()->type());
|
||||
}
|
||||
// The intermediate value is passed as the last argument.
|
||||
fn_arg_types.push_back(intermediate_type());
|
||||
RETURN_IF_ERROR(codegen->LoadFunction(fn_, symbol, NULL, fn_arg_types,
|
||||
fn_arg_types.size(), false, uda_fn, &cache_entry_));
|
||||
|
||||
// Inline constants into the function body (if there is an IR body).
|
||||
if (!(*uda_fn)->isDeclaration()) {
|
||||
// TODO: IMPALA-4785: we should also replace references to GetIntermediateType()
|
||||
// with constants.
|
||||
Expr::InlineConstants(GetOutputTypeDesc(), arg_type_descs_, codegen, *uda_fn);
|
||||
*uda_fn = codegen->FinalizeFunction(*uda_fn);
|
||||
if (*uda_fn == NULL) {
|
||||
return Status(TErrorCode::UDF_VERIFY_FAILED, symbol, fn_.hdfs_location);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
FunctionContext::TypeDesc AggFnEvaluator::GetIntermediateTypeDesc() const {
|
||||
return AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type());
|
||||
}
|
||||
|
||||
FunctionContext::TypeDesc AggFnEvaluator::GetOutputTypeDesc() const {
|
||||
return AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type());
|
||||
}
|
||||
|
||||
string AggFnEvaluator::DebugString(const vector<AggFnEvaluator*>& exprs) {
|
||||
stringstream out;
|
||||
out << "[";
|
||||
|
||||
@@ -119,8 +119,6 @@ class AggFnEvaluator {
|
||||
bool SupportsRemove() const { return remove_fn_ != NULL; }
|
||||
bool SupportsSerialize() const { return serialize_fn_ != NULL; }
|
||||
const std::string& fn_name() const { return fn_.name.function_name; }
|
||||
const std::string& update_symbol() const { return fn_.aggregate_fn.update_fn_symbol; }
|
||||
const std::string& merge_symbol() const { return fn_.aggregate_fn.merge_fn_symbol; }
|
||||
const SlotDescriptor* output_slot_desc() const { return output_slot_desc_; }
|
||||
|
||||
static std::string DebugString(const std::vector<AggFnEvaluator*>& exprs);
|
||||
@@ -168,14 +166,8 @@ class AggFnEvaluator {
|
||||
static void Finalize(const std::vector<AggFnEvaluator*>& evaluators,
|
||||
const std::vector<FunctionContext*>& fn_ctxs, Tuple* src, Tuple* dst);
|
||||
|
||||
/// TODO: implement codegen path. These functions would return IR functions with
|
||||
/// the same signature as the interpreted ones above.
|
||||
/// Function* GetIrInitFn();
|
||||
/// Function* GetIrAddFn();
|
||||
/// Function* GetIrRemoveFn();
|
||||
/// Function* GetIrSerializeFn();
|
||||
/// Function* GetIrGetValueFn();
|
||||
/// Function* GetIrFinalizeFn();
|
||||
/// Gets the codegened update or merge function for this aggregate function.
|
||||
Status GetUpdateOrMergeFunction(LlvmCodeGen* codegen, llvm::Function** uda_fn);
|
||||
|
||||
private:
|
||||
const TFunction fn_;
|
||||
@@ -195,6 +187,9 @@ class AggFnEvaluator {
|
||||
/// expression (e.g. count(*)).
|
||||
std::vector<ExprContext*> input_expr_ctxs_;
|
||||
|
||||
/// The types of the arguments to the aggregate function.
|
||||
const std::vector<FunctionContext::TypeDesc> arg_type_descs_;
|
||||
|
||||
/// The enum for some of the builtins that still require special cased logic.
|
||||
AggregationOp agg_op_;
|
||||
|
||||
@@ -221,6 +216,12 @@ class AggFnEvaluator {
|
||||
/// Use Create() instead.
|
||||
AggFnEvaluator(const TExprNode& desc, bool is_analytic_fn);
|
||||
|
||||
/// Return the intermediate type of the aggregate function.
|
||||
FunctionContext::TypeDesc GetIntermediateTypeDesc() const;
|
||||
|
||||
/// Return the output type of the aggregate function.
|
||||
FunctionContext::TypeDesc GetOutputTypeDesc() const;
|
||||
|
||||
/// TODO: these functions below are not extensible and we need to use codegen to
|
||||
/// generate the calls into the UDA functions (like for UDFs).
|
||||
/// Remove these functions when this is supported.
|
||||
|
||||
@@ -92,17 +92,33 @@ FunctionContext::TypeDesc AnyValUtil::ColumnTypeToTypeDesc(const ColumnType& typ
|
||||
return out;
|
||||
}
|
||||
|
||||
vector<FunctionContext::TypeDesc> AnyValUtil::ColumnTypesToTypeDescs(
|
||||
const vector<ColumnType>& types) {
|
||||
vector<FunctionContext::TypeDesc> type_descs;
|
||||
for (const ColumnType& type : types) type_descs.push_back(ColumnTypeToTypeDesc(type));
|
||||
return type_descs;
|
||||
}
|
||||
|
||||
ColumnType AnyValUtil::TypeDescToColumnType(const FunctionContext::TypeDesc& type) {
|
||||
switch (type.type) {
|
||||
case FunctionContext::TYPE_BOOLEAN: return ColumnType(TYPE_BOOLEAN);
|
||||
case FunctionContext::TYPE_TINYINT: return ColumnType(TYPE_TINYINT);
|
||||
case FunctionContext::TYPE_SMALLINT: return ColumnType(TYPE_SMALLINT);
|
||||
case FunctionContext::TYPE_INT: return ColumnType(TYPE_INT);
|
||||
case FunctionContext::TYPE_BIGINT: return ColumnType(TYPE_BIGINT);
|
||||
case FunctionContext::TYPE_FLOAT: return ColumnType(TYPE_FLOAT);
|
||||
case FunctionContext::TYPE_DOUBLE: return ColumnType(TYPE_DOUBLE);
|
||||
case FunctionContext::TYPE_TIMESTAMP: return ColumnType(TYPE_TIMESTAMP);
|
||||
case FunctionContext::TYPE_STRING: return ColumnType(TYPE_STRING);
|
||||
case FunctionContext::TYPE_BOOLEAN:
|
||||
return ColumnType(TYPE_BOOLEAN);
|
||||
case FunctionContext::TYPE_TINYINT:
|
||||
return ColumnType(TYPE_TINYINT);
|
||||
case FunctionContext::TYPE_SMALLINT:
|
||||
return ColumnType(TYPE_SMALLINT);
|
||||
case FunctionContext::TYPE_INT:
|
||||
return ColumnType(TYPE_INT);
|
||||
case FunctionContext::TYPE_BIGINT:
|
||||
return ColumnType(TYPE_BIGINT);
|
||||
case FunctionContext::TYPE_FLOAT:
|
||||
return ColumnType(TYPE_FLOAT);
|
||||
case FunctionContext::TYPE_DOUBLE:
|
||||
return ColumnType(TYPE_DOUBLE);
|
||||
case FunctionContext::TYPE_TIMESTAMP:
|
||||
return ColumnType(TYPE_TIMESTAMP);
|
||||
case FunctionContext::TYPE_STRING:
|
||||
return ColumnType(TYPE_STRING);
|
||||
case FunctionContext::TYPE_DECIMAL:
|
||||
return ColumnType::CreateDecimalType(type.precision, type.scale);
|
||||
case FunctionContext::TYPE_FIXED_BUFFER:
|
||||
|
||||
@@ -227,6 +227,8 @@ class AnyValUtil {
|
||||
}
|
||||
|
||||
static FunctionContext::TypeDesc ColumnTypeToTypeDesc(const ColumnType& type);
|
||||
static std::vector<FunctionContext::TypeDesc> ColumnTypesToTypeDescs(
|
||||
const std::vector<ColumnType>& types);
|
||||
// Note: constructing a ColumnType is expensive and should be avoided in query execution
|
||||
// paths (i.e. non-setup paths).
|
||||
static ColumnType TypeDescToColumnType(const FunctionContext::TypeDesc& type);
|
||||
|
||||
@@ -637,10 +637,10 @@ int Expr::InlineConstants(LlvmCodeGen* codegen, Function* fn) {
|
||||
}
|
||||
|
||||
int Expr::InlineConstants(const FunctionContext::TypeDesc& return_type,
|
||||
const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* codegen,
|
||||
Function* fn) {
|
||||
const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* codegen,
|
||||
Function* fn) {
|
||||
int replaced = 0;
|
||||
for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end; ) {
|
||||
for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end;) {
|
||||
// Increment iter now so we don't mess it up modifying the instruction below
|
||||
Instruction* instr = &*(iter++);
|
||||
|
||||
@@ -666,7 +666,7 @@ int Expr::InlineConstants(const FunctionContext::TypeDesc& return_type,
|
||||
int i_val = static_cast<int>(i_arg->getSExtValue());
|
||||
// All supported constants are currently integers.
|
||||
call_instr->replaceAllUsesWith(ConstantInt::get(codegen->GetType(TYPE_INT),
|
||||
GetConstantInt(return_type, arg_types, c_val, i_val)));
|
||||
GetConstantInt(return_type, arg_types, c_val, i_val)));
|
||||
call_instr->eraseFromParent();
|
||||
++replaced;
|
||||
}
|
||||
|
||||
@@ -263,9 +263,11 @@ class Expr {
|
||||
// Any additions to this enum must be reflected in both GetConstant*() and
|
||||
// GetIrConstant().
|
||||
enum ExprConstant {
|
||||
// RETURN_TYPE_*: properties of FunctionContext::GetReturnType().
|
||||
RETURN_TYPE_SIZE, // int
|
||||
RETURN_TYPE_PRECISION, // int
|
||||
RETURN_TYPE_SCALE, // int
|
||||
// ARG_TYPE_* with parameter i: properties of FunctionContext::GetArgType(i).
|
||||
ARG_TYPE_SIZE, // int[]
|
||||
ARG_TYPE_PRECISION, // int[]
|
||||
ARG_TYPE_SCALE, // int[]
|
||||
@@ -289,7 +291,8 @@ class Expr {
|
||||
// constants to be replaced must be inlined into the function that InlineConstants()
|
||||
// is run on (e.g. by annotating them with IR_ALWAYS_INLINE).
|
||||
//
|
||||
// TODO: implement a loop unroller (or use LLVM's) so we can use GetConstantInt() in loops
|
||||
// TODO: implement a loop unroller (or use LLVM's) so we can use GetConstantInt() in
|
||||
// loops
|
||||
static int GetConstantInt(const FunctionContext& ctx, ExprConstant c, int i = -1);
|
||||
|
||||
/// Finds all calls to Expr::GetConstantInt() in 'fn' and replaces them with the
|
||||
@@ -298,8 +301,8 @@ class Expr {
|
||||
/// 'arg_types' are the argument types of the UDF or UDAF, i.e. the values of
|
||||
/// FunctionContext::GetArgType().
|
||||
static int InlineConstants(const FunctionContext::TypeDesc& return_type,
|
||||
const std::vector<FunctionContext::TypeDesc>& arg_types,
|
||||
LlvmCodeGen* codegen, llvm::Function* fn);
|
||||
const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* codegen,
|
||||
llvm::Function* fn);
|
||||
|
||||
static const char* LLVM_CLASS_NAME;
|
||||
|
||||
|
||||
@@ -19,12 +19,12 @@
|
||||
|
||||
#include <vector>
|
||||
#include <gutil/strings/substitute.h>
|
||||
#include <llvm/IR/Attributes.h>
|
||||
#include <llvm/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <llvm/IR/Attributes.h>
|
||||
|
||||
#include <boost/preprocessor/punctuation/comma_if.hpp>
|
||||
#include <boost/preprocessor/repetition/repeat.hpp>
|
||||
#include <boost/preprocessor/repetition/enum_params.hpp>
|
||||
#include <boost/preprocessor/repetition/repeat.hpp>
|
||||
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
|
||||
|
||||
#include "codegen/codegen-anyval.h"
|
||||
@@ -37,8 +37,6 @@
|
||||
#include "runtime/types.h"
|
||||
#include "udf/udf-internal.h"
|
||||
#include "util/debug-util.h"
|
||||
#include "util/dynamic-util.h"
|
||||
#include "util/symbols-util.h"
|
||||
|
||||
#include "common/names.h"
|
||||
|
||||
@@ -311,20 +309,27 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) {
|
||||
}
|
||||
}
|
||||
|
||||
if (fn_.binary_type == TFunctionBinaryType::IR) {
|
||||
string local_path;
|
||||
RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath(
|
||||
fn_.hdfs_location, LibCache::TYPE_IR, &local_path));
|
||||
// Link the UDF module into this query's main module (essentially copy the UDF
|
||||
// module into the main module) so the UDF's functions are available in the main
|
||||
// module.
|
||||
RETURN_IF_ERROR(codegen->LinkModule(local_path));
|
||||
// Load the Prepare() and Close() functions from the LLVM module.
|
||||
RETURN_IF_ERROR(LoadPrepareAndCloseFn(codegen));
|
||||
vector<ColumnType> arg_types;
|
||||
for (const Expr* child : children_) arg_types.push_back(child->type());
|
||||
Function* udf;
|
||||
RETURN_IF_ERROR(codegen->LoadFunction(fn_, fn_.scalar_fn.symbol, &type_, arg_types,
|
||||
NumFixedArgs(), vararg_start_idx_ != -1, &udf, &cache_entry_));
|
||||
// Inline constants into the function if it has an IR body.
|
||||
if (!udf->isDeclaration()) {
|
||||
InlineConstants(AnyValUtil::ColumnTypeToTypeDesc(type_),
|
||||
AnyValUtil::ColumnTypesToTypeDescs(arg_types), codegen, udf);
|
||||
udf = codegen->FinalizeFunction(udf);
|
||||
if (udf == NULL) {
|
||||
return Status(
|
||||
TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location);
|
||||
}
|
||||
}
|
||||
|
||||
Function* udf;
|
||||
RETURN_IF_ERROR(GetUdf(codegen, &udf));
|
||||
if (fn_.binary_type == TFunctionBinaryType::IR) {
|
||||
// LoadFunction() should have linked the IR module into 'codegen'. Now load the
|
||||
// Prepare() and Close() functions from 'codegen'.
|
||||
RETURN_IF_ERROR(LoadPrepareAndCloseFn(codegen));
|
||||
}
|
||||
|
||||
// Create wrapper that computes args and calls UDF
|
||||
stringstream fn_name;
|
||||
@@ -407,8 +412,7 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) {
|
||||
// Add the number of varargs
|
||||
udf_args.push_back(codegen->GetIntConstant(TYPE_INT, NumVarArgs()));
|
||||
// Add all the accumulated vararg inputs as one input argument.
|
||||
PointerType* vararg_type =
|
||||
codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, VarArgsType()));
|
||||
PointerType* vararg_type = CodegenAnyVal::GetUnloweredPtrType(codegen, VarArgsType());
|
||||
udf_args.push_back(builder.CreateBitCast(varargs_buffer, vararg_type, "varargs"));
|
||||
}
|
||||
|
||||
@@ -428,119 +432,11 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ScalarFnCall::GetUdf(LlvmCodeGen* codegen, Function** udf) {
|
||||
// from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot be JIT'd.
|
||||
// TimestampFunctions::AddSub() contains a try/catch which doesn't work in JIT'd
|
||||
// code. Always use the interpreted version of these functions.
|
||||
// TODO: fix these built-in functions so we don't need 'broken_builtin' below.
|
||||
bool broken_builtin = fn_.name.function_name == "from_utc_timestamp" ||
|
||||
fn_.name.function_name == "to_utc_timestamp" ||
|
||||
fn_.scalar_fn.symbol.find("AddSub") != string::npos;
|
||||
if (fn_.binary_type == TFunctionBinaryType::NATIVE ||
|
||||
(fn_.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) {
|
||||
// In this path, we are code that has been statically compiled to assembly.
|
||||
// This can either be a UDF implemented in a .so or a builtin using the UDF
|
||||
// interface with the code in impalad.
|
||||
void* fn_ptr;
|
||||
Status status = LibCache::instance()->GetSoFunctionPtr(
|
||||
fn_.hdfs_location, fn_.scalar_fn.symbol, &fn_ptr, &cache_entry_);
|
||||
if (!status.ok() && fn_.binary_type == TFunctionBinaryType::BUILTIN) {
|
||||
// Builtins symbols should exist unless there is a version mismatch.
|
||||
status.AddDetail(ErrorMsg(TErrorCode::MISSING_BUILTIN,
|
||||
fn_.name.function_name, fn_.scalar_fn.symbol).msg());
|
||||
}
|
||||
RETURN_IF_ERROR(status);
|
||||
DCHECK(fn_ptr != NULL);
|
||||
|
||||
// Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument.
|
||||
// So, the return type is void.
|
||||
bool is_decimal = type().type == TYPE_DECIMAL;
|
||||
Type* return_type = is_decimal ? codegen->void_type() :
|
||||
CodegenAnyVal::GetLoweredType(codegen, type());
|
||||
|
||||
// Convert UDF function pointer to Function*. Start by creating a function
|
||||
// prototype for it.
|
||||
LlvmCodeGen::FnPrototype prototype(codegen, fn_.scalar_fn.symbol, return_type);
|
||||
|
||||
if (is_decimal) {
|
||||
// Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument
|
||||
Type* output_type =
|
||||
codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, type()));
|
||||
prototype.AddArgument("output", output_type);
|
||||
}
|
||||
|
||||
// The "FunctionContext*" argument.
|
||||
prototype.AddArgument("ctx",
|
||||
codegen->GetPtrType("class.impala_udf::FunctionContext"));
|
||||
|
||||
// The "fixed" arguments for the UDF function.
|
||||
for (int i = 0; i < NumFixedArgs(); ++i) {
|
||||
stringstream arg_name;
|
||||
arg_name << "fixed_arg_" << i;
|
||||
Type* arg_type = codegen->GetPtrType(
|
||||
CodegenAnyVal::GetUnloweredType(codegen, children_[i]->type()));
|
||||
prototype.AddArgument(arg_name.str(), arg_type);
|
||||
}
|
||||
// The varargs for the UDF function if there is any.
|
||||
if (NumVarArgs() > 0) {
|
||||
Type* vararg_type = CodegenAnyVal::GetUnloweredPtrType(
|
||||
codegen, children_[vararg_start_idx_]->type());
|
||||
prototype.AddArgument("num_var_arg", codegen->GetType(TYPE_INT));
|
||||
prototype.AddArgument("var_arg", vararg_type);
|
||||
}
|
||||
|
||||
// Create a Function* with the generated type. This is only a function
|
||||
// declaration, not a definition, since we do not create any basic blocks or
|
||||
// instructions in it.
|
||||
*udf = prototype.GeneratePrototype(NULL, NULL, false);
|
||||
|
||||
// Associate the dynamically loaded function pointer with the Function* we defined.
|
||||
// This tells LLVM where the compiled function definition is located in memory.
|
||||
codegen->execution_engine()->addGlobalMapping(*udf, fn_ptr);
|
||||
} else if (fn_.binary_type == TFunctionBinaryType::BUILTIN) {
|
||||
// In this path, we're running a builtin with the UDF interface. The IR is
|
||||
// in the llvm module.
|
||||
*udf = codegen->GetFunction(fn_.scalar_fn.symbol, false);
|
||||
if (*udf == NULL) {
|
||||
// Builtins symbols should exist unless there is a version mismatch.
|
||||
stringstream ss;
|
||||
ss << "Builtin '" << fn_.name.function_name << "' with symbol '"
|
||||
<< fn_.scalar_fn.symbol << "' does not exist. "
|
||||
<< "Verify that all your impalads are the same version.";
|
||||
return Status(ss.str());
|
||||
}
|
||||
// Builtin functions may use Expr::GetConstant(). Clone the function in case we need
|
||||
// to use it again, and rename it to something more manageable than the mangled name.
|
||||
string demangled_name = SymbolsUtil::DemangleNoArgs((*udf)->getName().str());
|
||||
*udf = codegen->CloneFunction(*udf);
|
||||
(*udf)->setName(demangled_name);
|
||||
InlineConstants(codegen, *udf);
|
||||
*udf = codegen->FinalizeFunction(*udf);
|
||||
DCHECK(*udf != NULL);
|
||||
} else {
|
||||
// We're running an IR UDF.
|
||||
DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR);
|
||||
*udf = codegen->GetFunction(fn_.scalar_fn.symbol, false);
|
||||
if (*udf == NULL) {
|
||||
stringstream ss;
|
||||
ss << "Unable to locate function " << fn_.scalar_fn.symbol << " from LLVM module "
|
||||
<< fn_.hdfs_location;
|
||||
return Status(ss.str());
|
||||
}
|
||||
*udf = codegen->FinalizeFunction(*udf);
|
||||
if (*udf == NULL) {
|
||||
return Status(
|
||||
TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ScalarFnCall::GetFunction(LlvmCodeGen* codegen, const string& symbol, void** fn) {
|
||||
if (fn_.binary_type == TFunctionBinaryType::NATIVE ||
|
||||
fn_.binary_type == TFunctionBinaryType::BUILTIN) {
|
||||
return LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, symbol, fn,
|
||||
&cache_entry_);
|
||||
if (fn_.binary_type == TFunctionBinaryType::NATIVE
|
||||
|| fn_.binary_type == TFunctionBinaryType::BUILTIN) {
|
||||
return LibCache::instance()->GetSoFunctionPtr(
|
||||
fn_.hdfs_location, symbol, fn, &cache_entry_);
|
||||
} else {
|
||||
DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR);
|
||||
DCHECK(codegen != NULL);
|
||||
|
||||
@@ -48,7 +48,7 @@ class TExprNode;
|
||||
/// - Test cancellation
|
||||
/// - Type descs in UDA test harness
|
||||
/// - Allow more functions to be NULL in UDA test harness
|
||||
class ScalarFnCall: public Expr {
|
||||
class ScalarFnCall : public Expr {
|
||||
public:
|
||||
virtual std::string DebugString() const;
|
||||
|
||||
@@ -117,9 +117,6 @@ class ScalarFnCall: public Expr {
|
||||
return children_.back()->type();
|
||||
}
|
||||
|
||||
/// Loads the native or IR function from HDFS and puts the result in *udf.
|
||||
Status GetUdf(LlvmCodeGen* codegen, llvm::Function** udf);
|
||||
|
||||
/// Loads the native or IR function 'symbol' from HDFS and puts the result in *fn.
|
||||
/// If the function is loaded from an IR module, it cannot be called until the module
|
||||
/// has been JIT'd (i.e. after GetCodegendComputeFn() has been called).
|
||||
|
||||
@@ -76,7 +76,7 @@ void ThrowIfDateOutOfRange(const boost::gregorian::date& date) {
|
||||
|
||||
// This function uses inline asm functions, which we believe to be from the boost library.
|
||||
// Inline asm is not currently supported by JIT, so this function should always be run in
|
||||
// the interpreted mode. This is handled in ScalarFnCall::GetUdf().
|
||||
// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction().
|
||||
TimestampVal TimestampFunctions::FromUtc(FunctionContext* context,
|
||||
const TimestampVal& ts_val, const StringVal& tz_string_val) {
|
||||
if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null();
|
||||
@@ -114,7 +114,7 @@ TimestampVal TimestampFunctions::FromUtc(FunctionContext* context,
|
||||
|
||||
// This function uses inline asm functions, which we believe to be from the boost library.
|
||||
// Inline asm is not currently supported by JIT, so this function should always be run in
|
||||
// the interpreted mode. This is handled in ScalarFnCall::GetUdf().
|
||||
// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction().
|
||||
TimestampVal TimestampFunctions::ToUtc(FunctionContext* context,
|
||||
const TimestampVal& ts_val, const StringVal& tz_string_val) {
|
||||
if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null();
|
||||
|
||||
@@ -310,6 +310,12 @@ string ColumnType::DebugString() const {
|
||||
}
|
||||
}
|
||||
|
||||
vector<ColumnType> ColumnType::FromThrift(const vector<TColumnType>& ttypes) {
|
||||
vector<ColumnType> types;
|
||||
for (const TColumnType& ttype : ttypes) types.push_back(FromThrift(ttype));
|
||||
return types;
|
||||
}
|
||||
|
||||
ostream& operator<<(ostream& os, const ColumnType& type) {
|
||||
os << type.DebugString();
|
||||
return os;
|
||||
|
||||
@@ -147,6 +147,8 @@ struct ColumnType {
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::vector<ColumnType> FromThrift(const std::vector<TColumnType>& ttypes);
|
||||
|
||||
bool operator==(const ColumnType& o) const {
|
||||
if (type != o.type) return false;
|
||||
if (children != o.children) return false;
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
|
||||
#include "testutil/test-udas.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
// Don't include Impala internal headers - real UDAs won't include them.
|
||||
#include <udf/udf.h>
|
||||
|
||||
using namespace impala_udf;
|
||||
@@ -48,19 +51,95 @@ void Agg(FunctionContext*, const StringVal&, const DoubleVal&, StringVal*) {}
|
||||
void AggInit(FunctionContext*, StringVal*){}
|
||||
void AggMerge(FunctionContext*, const StringVal&, StringVal*) {}
|
||||
StringVal AggSerialize(FunctionContext*, const StringVal& v) { return v;}
|
||||
StringVal AggFinalize(FunctionContext*, const StringVal& v) { return v;}
|
||||
StringVal AggFinalize(FunctionContext*, const StringVal& v) {
|
||||
return v;
|
||||
}
|
||||
|
||||
|
||||
// Defines AggIntermediate(int) returns BIGINT intermediate CHAR(10)
|
||||
// TODO: StringVal should be replaced with BufferVal in Impala 2.0
|
||||
void AggIntermediate(FunctionContext*, const IntVal&, StringVal*) {}
|
||||
void AggIntermediateUpdate(FunctionContext*, const IntVal&, StringVal*) {}
|
||||
void AggIntermediateInit(FunctionContext*, StringVal*) {}
|
||||
void AggIntermediateMerge(FunctionContext*, const StringVal&, StringVal*) {}
|
||||
BigIntVal AggIntermediateFinalize(FunctionContext*, const StringVal&) {
|
||||
// Defines AggIntermediate(int) returns BIGINT intermediate STRING
|
||||
void AggIntermediate(FunctionContext* context, const IntVal&, StringVal*) {}
|
||||
void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) {
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
}
|
||||
void AggIntermediateInit(FunctionContext* context, StringVal*) {
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
}
|
||||
void AggIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) {
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
}
|
||||
BigIntVal AggIntermediateFinalize(FunctionContext* context, const StringVal&) {
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
return BigIntVal::null();
|
||||
}
|
||||
|
||||
// Defines AggDecimalIntermediate(DECIMAL(1,2), INT) returns DECIMAL(5,6)
|
||||
// intermediate DECIMAL(3,4)
|
||||
// Useful to test that type parameters are plumbed through.
|
||||
void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, const IntVal&, DecimalVal*) {
|
||||
assert(context->GetNumArgs() == 2);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetArgType(0)->precision == 2);
|
||||
assert(context->GetArgType(0)->scale == 1);
|
||||
assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetIntermediateType().precision == 4);
|
||||
assert(context->GetIntermediateType().scale == 3);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetReturnType().precision == 6);
|
||||
assert(context->GetReturnType().scale == 5);
|
||||
}
|
||||
void AggDecimalIntermediateInit(FunctionContext* context, DecimalVal*) {
|
||||
assert(context->GetNumArgs() == 2);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetArgType(0)->precision == 2);
|
||||
assert(context->GetArgType(0)->scale == 1);
|
||||
assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetIntermediateType().precision == 4);
|
||||
assert(context->GetIntermediateType().scale == 3);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetReturnType().precision == 6);
|
||||
assert(context->GetReturnType().scale == 5);
|
||||
}
|
||||
void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, DecimalVal*) {
|
||||
assert(context->GetNumArgs() == 2);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetArgType(0)->precision == 2);
|
||||
assert(context->GetArgType(0)->scale == 1);
|
||||
assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetIntermediateType().precision == 4);
|
||||
assert(context->GetIntermediateType().scale == 3);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetReturnType().precision == 6);
|
||||
assert(context->GetReturnType().scale == 5);
|
||||
}
|
||||
DecimalVal AggDecimalIntermediateFinalize(FunctionContext* context, const DecimalVal&) {
|
||||
assert(context->GetNumArgs() == 2);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetArgType(0)->precision == 2);
|
||||
assert(context->GetArgType(0)->scale == 1);
|
||||
assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetIntermediateType().precision == 4);
|
||||
assert(context->GetIntermediateType().scale == 3);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
|
||||
assert(context->GetReturnType().precision == 6);
|
||||
assert(context->GetReturnType().scale == 5);
|
||||
return DecimalVal::null();
|
||||
}
|
||||
|
||||
// Defines MemTest(bigint) return bigint
|
||||
// "Allocates" the specified number of bytes in the update function and frees them in the
|
||||
// serialize function. Useful for testing mem limits.
|
||||
@@ -99,22 +178,57 @@ BigIntVal MemTestFinalize(FunctionContext* context, const BigIntVal& total) {
|
||||
// Defines aggregate function for testing different intermediate/output types that
|
||||
// computes the truncated bigint sum of many floats.
|
||||
void TruncSumInit(FunctionContext* context, DoubleVal* total) {
|
||||
// Arg types should be logical input types of UDA.
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetArgType(-1) == nullptr);
|
||||
assert(context->GetArgType(1) == nullptr);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
*total = DoubleVal(0);
|
||||
}
|
||||
|
||||
void TruncSumUpdate(FunctionContext* context, const DoubleVal& val, DoubleVal* total) {
|
||||
// Arg types should be logical input types of UDA.
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetArgType(-1) == nullptr);
|
||||
assert(context->GetArgType(1) == nullptr);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
total->val += val.val;
|
||||
}
|
||||
|
||||
void TruncSumMerge(FunctionContext* context, const DoubleVal& src, DoubleVal* dst) {
|
||||
// Arg types should be logical input types of UDA.
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetArgType(-1) == nullptr);
|
||||
assert(context->GetArgType(1) == nullptr);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
dst->val += src.val;
|
||||
}
|
||||
|
||||
const DoubleVal TruncSumSerialize(FunctionContext* context, const DoubleVal& total) {
|
||||
// Arg types should be logical input types of UDA.
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetArgType(-1) == nullptr);
|
||||
assert(context->GetArgType(1) == nullptr);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
return total;
|
||||
}
|
||||
|
||||
BigIntVal TruncSumFinalize(FunctionContext* context, const DoubleVal& total) {
|
||||
// Arg types should be logical input types of UDA.
|
||||
assert(context->GetNumArgs() == 1);
|
||||
assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetArgType(-1) == nullptr);
|
||||
assert(context->GetArgType(1) == nullptr);
|
||||
assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
|
||||
assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
|
||||
return BigIntVal(static_cast<int64_t>(total.val));
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#ifndef IMPALA_UDF_TEST_UDAS_H
|
||||
#define IMPALA_UDF_TEST_UDAS_H
|
||||
|
||||
// Don't include Impala internal headers - real UDAs won't include them.
|
||||
#include "udf/udf.h"
|
||||
|
||||
using namespace impala_udf;
|
||||
|
||||
@@ -47,6 +47,12 @@ class RuntimeState;
|
||||
/// This class actually implements the interface of FunctionContext. This is split to
|
||||
/// hide the details from the external header.
|
||||
/// Note: The actual user code does not include this file.
|
||||
///
|
||||
/// Exprs (e.g. UDFs and UDAs) require a FunctionContext to store state related to
|
||||
/// evaluation of the expression. Each FunctionContext is associated with a backend Expr
|
||||
/// or AggFnEvaluator, which is derived from a TExprNode generated by the Impala frontend.
|
||||
/// FunctionContexts are allocated and managed by ExprContext. Exprs shouldn't try to
|
||||
/// create FunctionContext themselves.
|
||||
class FunctionContextImpl {
|
||||
public:
|
||||
/// Create a FunctionContext for a UDF. Caller is responsible for deleting it.
|
||||
|
||||
@@ -34,6 +34,10 @@ int FunctionContext::GetNumArgs() const {
|
||||
return impl_->arg_types_.size();
|
||||
}
|
||||
|
||||
const FunctionContext::TypeDesc& FunctionContext::GetIntermediateType() const {
|
||||
return impl_->intermediate_type_;
|
||||
}
|
||||
|
||||
const FunctionContext::TypeDesc& FunctionContext::GetReturnType() const {
|
||||
return impl_->return_type_;
|
||||
}
|
||||
|
||||
@@ -189,21 +189,25 @@ class FunctionContext {
|
||||
const TypeDesc& GetIntermediateType() const;
|
||||
|
||||
/// Returns the number of arguments to this function (not including the FunctionContext*
|
||||
/// argument).
|
||||
/// argument or the output of a UDA).
|
||||
/// For UDAs, returns the number of logical arguments of the aggregate function, not
|
||||
/// the number of arguments of the C++ function being executed.
|
||||
int GetNumArgs() const;
|
||||
|
||||
/// Returns the type information for the arg_idx-th argument (0-indexed, not including
|
||||
/// the FunctionContext* argument). Returns NULL if arg_idx is invalid.
|
||||
/// For UDAs, returns the logical argument types of the aggregate function, not the
|
||||
/// argument types of the C++ function being executed.
|
||||
const TypeDesc* GetArgType(int arg_idx) const;
|
||||
|
||||
/// Returns true if the arg_idx-th input argument (0 indexed, not including the
|
||||
/// FunctionContext* argument) is a constant (e.g. 5, "string", 1 + 1).
|
||||
/// Returns true if the arg_idx-th input argument (indexed in the same way as
|
||||
/// GetArgType()) is a constant (e.g. 5, "string", 1 + 1).
|
||||
bool IsArgConstant(int arg_idx) const;
|
||||
|
||||
/// Returns a pointer to the value of the arg_idx-th input argument (0 indexed, not
|
||||
/// including the FunctionContext* argument). Returns NULL if the argument is not
|
||||
/// constant. This function can be used to obtain user-specified constants in a UDF's
|
||||
/// Init() or Close() functions.
|
||||
/// Returns a pointer to the value of the arg_idx-th input argument (indexed in the
|
||||
/// same way as GetArgType()). Returns NULL if the argument is not constant. This
|
||||
/// function can be used to obtain user-specified constants in a UDF's Init() or
|
||||
/// Close() functions.
|
||||
AnyVal* GetConstantArg(int arg_idx) const;
|
||||
|
||||
/// TODO: Do we need to add arbitrary key/value metadata. This would be plumbed
|
||||
|
||||
@@ -112,9 +112,14 @@ struct TStringLiteral {
|
||||
1: required string value;
|
||||
}
|
||||
|
||||
// Additional information for aggregate functions.
|
||||
struct TAggregateExpr {
|
||||
// Indicates whether this expr is the merge() of an aggregation.
|
||||
1: required bool is_merge_agg
|
||||
|
||||
// The types of the input arguments to the aggregate function. May differ from the
|
||||
// input expr types if this is the merge() of an aggregation.
|
||||
2: required list<Types.TColumnType> arg_types;
|
||||
}
|
||||
|
||||
// This is essentially a union over the subclasses of Expr.
|
||||
|
||||
@@ -673,7 +673,8 @@ public class AggregateInfo extends AggregateInfoBase {
|
||||
* materialized slots of the output tuple corresponds to the number of materialized
|
||||
* aggregate functions plus the number of grouping exprs. Also checks that the return
|
||||
* types of the aggregate and grouping exprs correspond to the slots in the output
|
||||
* tuple.
|
||||
* tuple and that the input types stored in the merge aggregation are consistent
|
||||
* with the input exprs.
|
||||
*/
|
||||
public void checkConsistency() {
|
||||
ArrayList<SlotDescriptor> slots = outputTupleDesc_.getSlots();
|
||||
@@ -707,6 +708,13 @@ public class AggregateInfo extends AggregateInfoBase {
|
||||
slotType.toString()));
|
||||
++slotIdx;
|
||||
}
|
||||
if (mergeAggInfo_ != null) {
|
||||
// Check that the argument types in mergeAggInfo_ are consistent with input exprs.
|
||||
for (int i = 0; i < aggregateExprs_.size(); ++i) {
|
||||
FunctionCallExpr mergeAggExpr = mergeAggInfo_.aggregateExprs_.get(i);
|
||||
mergeAggExpr.validateMergeAggFn(aggregateExprs_.get(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -30,6 +30,7 @@ import org.apache.impala.catalog.Type;
|
||||
import org.apache.impala.common.AnalysisException;
|
||||
import org.apache.impala.common.TreeNode;
|
||||
import org.apache.impala.thrift.TAggregateExpr;
|
||||
import org.apache.impala.thrift.TColumnType;
|
||||
import org.apache.impala.thrift.TExprNode;
|
||||
import org.apache.impala.thrift.TExprNodeType;
|
||||
import org.apache.impala.thrift.TFunctionBinaryType;
|
||||
@@ -45,10 +46,12 @@ public class FunctionCallExpr extends Expr {
|
||||
private boolean isAnalyticFnCall_ = false;
|
||||
private boolean isInternalFnCall_ = false;
|
||||
|
||||
// Indicates whether this is a merge aggregation function that should use the merge
|
||||
// instead of the update symbol. This flag also affects the behavior of
|
||||
// resetAnalysisState() which is used during expr substitution.
|
||||
private final boolean isMergeAggFn_;
|
||||
// Non-null iff this is an aggregation function that executes the Merge() step. This
|
||||
// is an analyzed clone of the FunctionCallExpr that executes the Update() function
|
||||
// feeding into this Merge(). This is stored so that we can access the types of the
|
||||
// original input argument exprs. Note that the nullness affects the behaviour of
|
||||
// resetAnalysisState(), which is used during expr substitution.
|
||||
private final FunctionCallExpr mergeAggInputFn_;
|
||||
|
||||
// Printed in toSqlImpl(), if set. Used for merge agg fns.
|
||||
private String label_;
|
||||
@@ -62,15 +65,16 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
|
||||
public FunctionCallExpr(FunctionName fnName, FunctionParams params) {
|
||||
this(fnName, params, false);
|
||||
this(fnName, params, null);
|
||||
}
|
||||
|
||||
private FunctionCallExpr(
|
||||
FunctionName fnName, FunctionParams params, boolean isMergeAggFn) {
|
||||
private FunctionCallExpr(FunctionName fnName, FunctionParams params,
|
||||
FunctionCallExpr mergeAggInputFn) {
|
||||
super();
|
||||
fnName_ = fnName;
|
||||
params_ = params;
|
||||
isMergeAggFn_ = isMergeAggFn;
|
||||
mergeAggInputFn_ =
|
||||
mergeAggInputFn == null ? null : (FunctionCallExpr)mergeAggInputFn.clone();
|
||||
if (params.exprs() != null) children_ = Lists.newArrayList(params_.exprs());
|
||||
}
|
||||
|
||||
@@ -99,12 +103,12 @@ public class FunctionCallExpr extends Expr {
|
||||
Preconditions.checkState(agg.isAnalyzed());
|
||||
Preconditions.checkState(agg.isAggregateFunction());
|
||||
FunctionCallExpr result = new FunctionCallExpr(
|
||||
agg.fnName_, new FunctionParams(false, params), true);
|
||||
agg.fnName_, new FunctionParams(false, params), agg);
|
||||
// Inherit the function object from 'agg'.
|
||||
result.fn_ = agg.fn_;
|
||||
result.type_ = agg.type_;
|
||||
// Set an explicit label based on the input agg.
|
||||
if (agg.isMergeAggFn_) {
|
||||
if (agg.isMergeAggFn()) {
|
||||
result.label_ = agg.label_;
|
||||
} else {
|
||||
// fn(input) becomes fn:merge(input).
|
||||
@@ -123,7 +127,8 @@ public class FunctionCallExpr extends Expr {
|
||||
fnName_ = other.fnName_;
|
||||
isAnalyticFnCall_ = other.isAnalyticFnCall_;
|
||||
isInternalFnCall_ = other.isInternalFnCall_;
|
||||
isMergeAggFn_ = other.isMergeAggFn_;
|
||||
mergeAggInputFn_ =
|
||||
other.mergeAggInputFn_ == null ? null : (FunctionCallExpr)other.mergeAggInputFn_.clone();
|
||||
// Clone the params in a way that keeps the children_ and the params.exprs()
|
||||
// in sync. The children have already been cloned in the super c'tor.
|
||||
if (other.params_.isStar()) {
|
||||
@@ -135,7 +140,7 @@ public class FunctionCallExpr extends Expr {
|
||||
label_ = other.label_;
|
||||
}
|
||||
|
||||
public boolean isMergeAggFn() { return isMergeAggFn_; }
|
||||
public boolean isMergeAggFn() { return mergeAggInputFn_ != null; }
|
||||
|
||||
@Override
|
||||
public void resetAnalysisState() {
|
||||
@@ -144,7 +149,7 @@ public class FunctionCallExpr extends Expr {
|
||||
// intermediate agg type is not the same as the output type. Preserve the original
|
||||
// fn_ such that analyze() hits the special-case code for merge agg fns that
|
||||
// handles this case.
|
||||
if (!isMergeAggFn_) fn_ = null;
|
||||
if (!isMergeAggFn()) fn_ = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -160,7 +165,7 @@ public class FunctionCallExpr extends Expr {
|
||||
public String toSqlImpl() {
|
||||
if (label_ != null) return label_;
|
||||
// Merge agg fns should have an explicit label.
|
||||
Preconditions.checkState(!isMergeAggFn_);
|
||||
Preconditions.checkState(!isMergeAggFn());
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append(fnName_).append("(");
|
||||
if (params_.isStar()) sb.append("*");
|
||||
@@ -226,7 +231,12 @@ public class FunctionCallExpr extends Expr {
|
||||
protected void toThrift(TExprNode msg) {
|
||||
if (isAggregateFunction() || isAnalyticFnCall_) {
|
||||
msg.node_type = TExprNodeType.AGGREGATE_EXPR;
|
||||
if (!isAnalyticFnCall_) msg.setAgg_expr(new TAggregateExpr(isMergeAggFn_));
|
||||
List<TColumnType> aggFnArgTypes = Lists.newArrayList();
|
||||
FunctionCallExpr inputAggFn = isMergeAggFn() ? mergeAggInputFn_ : this;
|
||||
for (Expr child: inputAggFn.children_) {
|
||||
aggFnArgTypes.add(child.getType().toThrift());
|
||||
}
|
||||
msg.setAgg_expr(new TAggregateExpr(isMergeAggFn(), aggFnArgTypes));
|
||||
} else {
|
||||
msg.node_type = TExprNodeType.FUNCTION_CALL;
|
||||
}
|
||||
@@ -383,7 +393,7 @@ public class FunctionCallExpr extends Expr {
|
||||
protected void analyzeImpl(Analyzer analyzer) throws AnalysisException {
|
||||
fnName_.analyze(analyzer);
|
||||
|
||||
if (isMergeAggFn_) {
|
||||
if (isMergeAggFn()) {
|
||||
// This is the function call expr after splitting up to a merge aggregation.
|
||||
// The function has already been analyzed so just do the minimal sanity
|
||||
// check here.
|
||||
@@ -524,6 +534,25 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate that the internal state, specifically types, is consistent between the
|
||||
* the Update() and Merge() aggregate functions.
|
||||
*/
|
||||
void validateMergeAggFn(FunctionCallExpr inputAggFn) {
|
||||
Preconditions.checkState(isMergeAggFn());
|
||||
List<Expr> copiedInputExprs = mergeAggInputFn_.getChildren();
|
||||
List<Expr> inputExprs = inputAggFn.getChildren();
|
||||
Preconditions.checkState(copiedInputExprs.size() == inputExprs.size());
|
||||
for (int i = 0; i < inputExprs.size(); ++i) {
|
||||
Type copiedInputType = copiedInputExprs.get(i).getType();
|
||||
Type inputType = inputExprs.get(i).getType();
|
||||
Preconditions.checkState(copiedInputType.equals(inputType),
|
||||
String.format("Copied expr %s arg type %s differs from input expr type %s " +
|
||||
"in original expr %s", toSql(), copiedInputType.toSql(),
|
||||
inputType.toSql(), inputAggFn.toSql()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expr clone() { return new FunctionCallExpr(this); }
|
||||
}
|
||||
|
||||
@@ -69,3 +69,22 @@ from functional.alltypesagg
|
||||
---- TYPES
|
||||
bigint,bigint
|
||||
====
|
||||
---- QUERY
|
||||
# Test that all types are exposed via the FunctionContext correctly.
|
||||
# This relies on asserts in the UDA funciton
|
||||
select agg_intermediate(int_col), count(*)
|
||||
from functional.alltypesagg
|
||||
---- RESULTS
|
||||
NULL,11000
|
||||
---- TYPES
|
||||
bigint,bigint
|
||||
====
|
||||
---- QUERY
|
||||
# Test that all types are exposed via the FunctionContext correctly.
|
||||
# This relies on asserts in the UDA funciton
|
||||
select agg_decimal_intermediate(cast(d1 as decimal(2,1)), 2), count(*)
|
||||
from functional.decimal_tbl
|
||||
---- RESULTS
|
||||
NULL,5
|
||||
---- TYPES
|
||||
decimal,bigint
|
||||
|
||||
@@ -93,6 +93,16 @@ update_fn='ToggleNullUpdate' merge_fn='ToggleNullMerge';
|
||||
create aggregate function {database}.count_nulls(bigint)
|
||||
returns bigint location '{location}'
|
||||
update_fn='CountNullsUpdate' merge_fn='CountNullsMerge';
|
||||
|
||||
create aggregate function {database}.agg_intermediate(int)
|
||||
returns bigint intermediate string location '{location}'
|
||||
init_fn='AggIntermediateInit' update_fn='AggIntermediateUpdate'
|
||||
merge_fn='AggIntermediateMerge' finalize_fn='AggIntermediateFinalize';
|
||||
|
||||
create aggregate function {database}.agg_decimal_intermediate(decimal(2,1), int)
|
||||
returns decimal(6,5) intermediate decimal(4,3) location '{location}'
|
||||
init_fn='AggDecimalIntermediateInit' update_fn='AggDecimalIntermediateUpdate'
|
||||
merge_fn='AggDecimalIntermediateMerge' finalize_fn='AggDecimalIntermediateFinalize';
|
||||
"""
|
||||
|
||||
# Create test UDF functions in {database} from library {location}
|
||||
|
||||
Reference in New Issue
Block a user