diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc index fb122f5cd..c8dd6f5e4 100644 --- a/be/src/exprs/expr-test.cc +++ b/be/src/exprs/expr-test.cc @@ -102,6 +102,21 @@ string LiteralToString(double val) { return ss.str(); } +// For writing C++ std::strings as impala literals, we have to ensure that characters like +// single-quote are escaped properly. To do this, we escape every character into its octal +// equivalent: \PQR for some octal digits P, Q, and R. Currently, this only works for +// ASCII literals. +string StringToOctalLiteral(const string& s) { + string result(4 * s.size(), 0); + for (int i = 0; i < s.size(); ++i) { + result[4 * i] = '\\'; + result[4 * i + 1] = '0' + (s[i] / 64); + result[4 * i + 2] = '0' + ((s[i] / 8) % 8); + result[4 * i + 3] = '0' + (s[i] % 8); + } + return result; +} + // Override the time zone for the duration of the scope. The time zone is overridden // using an environment variable there is no risk of making a permanent system change // and no special permissions are needed. This is not thread-safe. @@ -2081,6 +2096,36 @@ TEST_F(ExprTest, StringFunctions) { big_str[ColumnType::MAX_VARCHAR_LENGTH] = '\0'; sprintf(query, "cast('%sxxx' as VARCHAR(%d))", big_str, ColumnType::MAX_VARCHAR_LENGTH); TestStringValue(query, big_str); + + // base64{en,de}code + + // Test some known values of base64{en,de}code + TestIsNull("base64encode(NULL)", TYPE_STRING); + TestIsNull("base64decode(NULL)", TYPE_STRING); + TestStringValue("base64encode('')", ""); + TestStringValue("base64decode('')", ""); + TestStringValue("base64encode('a')","YQ=="); + TestStringValue("base64decode('YQ==')","a"); + TestStringValue("base64encode('alpha')","YWxwaGE="); + TestStringValue("base64decode('YWxwaGE=')","alpha"); + TestIsNull("base64decode('YWxwaGE')", TYPE_STRING); + TestIsNull("base64decode('YWxwaGE%')", TYPE_STRING); + + // Test random short strings. + srand(0); + for (int length = 1; length < 100; ++length) { + for (int iteration = 0; iteration < 10; ++iteration) { + string raw(length, ' '); + for (int j = 0; j < length; ++j) { + raw[j] = rand() % 128; + } + const string as_octal = StringToOctalLiteral(raw); + TestValue("length(base64encode('" + as_octal + "')) > length('" + as_octal + "')", + TYPE_BOOLEAN, true); + TestValue("base64decode(base64encode('" + as_octal + "')) = '" + as_octal + "'", + TYPE_BOOLEAN, true); + } + } } TEST_F(ExprTest, StringRegexpFunctions) { diff --git a/be/src/exprs/string-functions.cc b/be/src/exprs/string-functions.cc index 26a642ab4..542d1d850 100644 --- a/be/src/exprs/string-functions.cc +++ b/be/src/exprs/string-functions.cc @@ -24,6 +24,7 @@ #include "exprs/expr.h" #include "runtime/string-value.inline.h" #include "runtime/tuple-row.h" +#include "sasl/saslutil.h" #include "util/url-parser.h" #include "common/names.h" @@ -792,4 +793,72 @@ StringVal StringFunctions::SplitPart(FunctionContext* context, return StringVal(); } +StringVal StringFunctions::Base64Encode(FunctionContext* ctx, const StringVal& str) { + if (str.is_null) return StringVal::null(); + if (str.len == 0) return StringVal(ctx, 0); + // Base64 encoding turns every 3 bytes into 4 characters. If the length is not divisible + // by 3, it pads the input with extra 0 bytes until it is divisible by 3. One more + // character must be allocated to account for sasl_encode64's null-padding of its + // output. + const unsigned out_max = 1 + 4 * ((static_cast(str.len) + 2) / 3); + if (UNLIKELY(out_max > static_cast(std::numeric_limits::max()))) { + stringstream ss; + ss << "Could not base64 encode a string of length " << str.len; + ctx->AddWarning(ss.str().c_str()); + return StringVal::null(); + } + StringVal result(ctx, out_max); + if (UNLIKELY(result.is_null)) return result; + unsigned out_len = 0; + const int encode_result = sasl_encode64(reinterpret_cast(str.ptr), str.len, + reinterpret_cast(result.ptr), out_max, &out_len); + if (UNLIKELY(encode_result != SASL_OK || out_len != out_max - 1)) { + stringstream ss; + ss << "Could not base64 encode input in space " << out_max + << "; actual output length " << out_len; + ctx->AddWarning(ss.str().c_str()); + return StringVal::null(); + } + result.len = out_len; + return result; +} + +StringVal StringFunctions::Base64Decode(FunctionContext* ctx, const StringVal& str) { + if (str.is_null) return StringVal::null(); + if (0 == str.len) return StringVal(ctx, 0); + // Base64 decoding turns every 4 characters into 3 bytes. If the last character of the + // encoded string is '=', that character (which represents 6 bits) and the last two bits + // of the previous character is ignored, for a total of 8 ignored bits, therefore + // producing one fewer byte of output. This is repeated if the second-to-last character + // is '='. One more byte must be allocated to account for sasl_decode64's null-padding + // of its output. + if (UNLIKELY((str.len & 3) != 0)) { + stringstream ss; + ss << "Invalid base64 string; input length is " << str.len + << ", which is not a multiple of 4."; + ctx->AddWarning(ss.str().c_str()); + return StringVal::null(); + } + unsigned out_max = 1 + 3 * (str.len / 4); + if (static_cast(str.ptr[str.len - 1]) == '=') { + --out_max; + if (static_cast(str.ptr[str.len - 2]) == '=') { + --out_max; + } + } + StringVal result(ctx, out_max); + if (UNLIKELY(result.is_null)) return result; + unsigned out_len = 0; + const int decode_result = sasl_decode64(reinterpret_cast(str.ptr), str.len, + reinterpret_cast(result.ptr), out_max, &out_len); + if (UNLIKELY(decode_result != SASL_OK || out_len != out_max - 1)) { + stringstream ss; + ss << "Could not base64 decode input in space " << out_max + << "; actual output length " << out_len; + ctx->AddWarning(ss.str().c_str()); + return StringVal::null(); + } + result.len = out_len; + return result; +} } diff --git a/be/src/exprs/string-functions.h b/be/src/exprs/string-functions.h index df7e5557c..d8419ab60 100644 --- a/be/src/exprs/string-functions.h +++ b/be/src/exprs/string-functions.h @@ -100,6 +100,9 @@ class StringFunctions { /// both ends of string 'str'. static StringVal BTrimString(FunctionContext* ctx, const StringVal& str, const StringVal& chars_to_trim); + + static StringVal Base64Encode(FunctionContext* ctx, const StringVal& str); + static StringVal Base64Decode(FunctionContext* ctx, const StringVal& str); }; } #endif diff --git a/common/function-registry/impala_functions.py b/common/function-registry/impala_functions.py index 73369bddd..879ff8a77 100644 --- a/common/function-registry/impala_functions.py +++ b/common/function-registry/impala_functions.py @@ -403,6 +403,8 @@ visible_functions = [ 'impala::StringFunctions::Substring'], [['split_part'], 'STRING', ['STRING', 'STRING', 'BIGINT'], 'impala::StringFunctions::SplitPart'], + [['base64encode'], 'STRING', ['STRING'], 'impala::StringFunctions::Base64Encode'], + [['base64decode'], 'STRING', ['STRING'], 'impala::StringFunctions::Base64Decode'], # left and right are key words, leave them out for now. [['strleft'], 'STRING', ['STRING', 'BIGINT'], 'impala::StringFunctions::Left'], [['strright'], 'STRING', ['STRING', 'BIGINT'], 'impala::StringFunctions::Right'], diff --git a/testdata/workloads/functional-query/queries/QueryTest/exprs.test b/testdata/workloads/functional-query/queries/QueryTest/exprs.test index 373815ecd..7234d4dfe 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/exprs.test +++ b/testdata/workloads/functional-query/queries/QueryTest/exprs.test @@ -2415,3 +2415,42 @@ NULL ---- TYPES timestamp ==== +---- QUERY +# base64 encoding/decoding +select count(*) from functional.alltypes +where length(string_col) > 0 && +length(base64encode(string_col)) <= length(string_col) +---- RESULTS +0 +---- TYPES +BIGINT +==== +---- QUERY +# base64 encoding/decoding +select count (*) from functional.alltypes +where base64decode(base64encode(string_col)) IS DISTINCT FROM string_col +---- RESULTS +0 +---- TYPES +BIGINT +==== +---- QUERY +# base64 decoding a string of invalid length (must be divisible by 4) +select base64decode('foo') +---- RESULTS +'NULL' +---- TYPES +STRING +---- ERRORS +UDF WARNING: Invalid base64 string; input length is 3, which is not a multiple of 4. +==== +---- QUERY +# base64 decoding a string with invalid characters +select base64decode('abc%') +---- RESULTS +'NULL' +---- TYPES +STRING +---- ERRORS +UDF WARNING: Could not base64 decode input in space 4; actual output length 0 +==== \ No newline at end of file