From 3850d49711b88091101cfc3d89da28c76a17b04d Mon Sep 17 00:00:00 2001 From: stiga-huang Date: Mon, 16 Aug 2021 18:04:19 +0800 Subject: [PATCH] IMPALA-9662,IMPALA-2019(part-3): Support UTF-8 mode in mask functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mask functions are used in Ranger column masking policies to mask sensitive data. There are 5 mask functions: mask(), mask_first_n(), mask_last_n(), mask_show_first_n(), mask_show_last_n(). Take mask() as an example: by default, it masks uppercase to 'X', lowercase to 'x', digits to 'n' and leaves other characters unmasked. For masking all characters to '*', we can use mask(my_col, '*', '*', '*', '*'); The current implementations mask strings byte by byte, which gives results inconsistent with Hive when the string contains unicode characters: mask('中国', '*', '*', '*', '*') => '******' Each Chinese character is encoded into 3 bytes in UTF-8 so we get the above result. The result in Hive is '**' since there are two Chinese characters. This patch provides consistent masking behavior with Hive for strings under the UTF-8 mode, i.e., set UTF8_MODE=true. In UTF-8 mode, the masked unit of a string is a unicode code point. Implementation - Extends the existing MaskTransform function to deal with unicode code points (represented by uint32_t). - Extends the existing GetFirstChar function to get the code point of given masked characters in UTF-8 mode. - Implements a MaskSubStrUtf8 method as the core functionality. - Switches to use MaskSubStrUtf8 instead of MaskSubStr in UTF-8 mode. - For better testing, this patch also adds an overload for all mask functions for only masking other chars but keeping the upper/lower/digit chars unmasked. E.g. mask({col}, -1, -1, -1, 'X'). 
Tests - Add BE tests in expr-test - Add e2e tests in utf8-string-functions.test Change-Id: I1276eccc94c9528507349b155a51e76f338367d5 Reviewed-on: http://gerrit.cloudera.org:8080/17780 Reviewed-by: Impala Public Jenkins Tested-by: Impala Public Jenkins --- CMakeLists.txt | 2 +- be/src/exprs/expr-test.cc | 59 ++++ be/src/exprs/mask-functions-ir.cc | 305 +++++++++++++++--- be/src/exprs/mask-functions.h | 30 ++ common/function-registry/impala_functions.py | 10 + .../QueryTest/utf8-string-functions.test | 12 + 6 files changed, 364 insertions(+), 54 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index de769c5f1..571886c83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -166,7 +166,7 @@ function(IMPALA_ADD_THIRDPARTY_LIB NAME HEADER STATIC_LIB SHARED_LIB) endfunction() -find_package(Boost REQUIRED COMPONENTS thread regex filesystem system date_time random) +find_package(Boost REQUIRED COMPONENTS thread regex filesystem system date_time random locale) # Mark Boost as a system header to avoid compile warnings. include_directories(SYSTEM ${Boost_INCLUDE_DIRS}) message(STATUS "Boost include dir: " ${Boost_INCLUDE_DIRS}) diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc index 46f2113a9..a7c0aa784 100644 --- a/be/src/exprs/expr-test.cc +++ b/be/src/exprs/expr-test.cc @@ -10646,6 +10646,65 @@ TEST_P(ExprTest, MaskHashTest) { TestIsNull("mask_hash(cast('2016-04-20' as timestamp))", TYPE_TIMESTAMP); } +TEST_P(ExprTest, Utf8MaskTest) { + executor_->PushExecOption("utf8_mode=true"); + // Default is no masking for other chars so Chinese charactors are unmasked. + TestStringValue("mask('hello李小龙')", "xxxxx李小龙"); + // Keeps upper, lower, digit chars and masks other chars as 'x'. 
+ TestStringValue("mask('hello李小龙', -1, -1, -1, 'X')", "helloXXX"); + TestStringValue("mask_last_n('hello李小龙', 4, -1, -1, -1, 'x')", "helloxxx"); + TestStringValue("mask_last_n('hello李小龙', 2, -1, -1, -1, 'x')", "hello李xx"); + TestStringValue("mask_last_n('hello李小龙', 4, 'x', 'x', 'x', 'X')", "hellxXXX"); + TestStringValue("mask_show_first_n('hello李小龙', 6, 'x', 'x', 'x', 'X')", + "hello李XX"); + TestStringValue("mask_show_first_n('hello李小龙', 4, -1, -1, -1, 'X')", "helloXXX"); + TestStringValue("mask_show_first_n('hello李小龙', 4, 'x', 'x', 'x', 'X')", + "hellxXXX"); + TestStringValue("mask_first_n('hello李小龙', 5)", "xxxxx李小龙"); + // Default is no masking for other chars so Chinese charactors are unmasked. + TestStringValue("mask_first_n('hello李小龙', 6)", "xxxxx李小龙"); + TestStringValue("mask_first_n('hello李小龙', 6, 'x', 'x', 'x', 'X')", + "xxxxxX小龙"); + TestStringValue("mask_show_last_n('hello李小龙', 2, 'x', 'x', 'x', 'X')", + "xxxxxX小龙"); + TestStringValue("mask_show_last_n('hello李小龙', 4, 'x', 'x', 'x', 'X')", + "xxxxo李小龙"); + + // Test masking unicode upper/lower cases. + TestStringValue("mask('abcd áäèü ABCD ÁÄÈÜ')", "xxxx xxxx XXXX XXXX"); + TestStringValue("mask('Ich möchte ein Bier. Tschüss')", + "Xxx xxxxxx xxx Xxxx. Xxxxxxx"); + TestStringValue("mask('Hungarian áéíöóőüúű ÁÉÍÖÓŐÜÚŰ')", + "Xxxxxxxxx xxxxxxxxx XXXXXXXXX"); + TestStringValue("mask('German äöüß ÄÖÜẞ')", "Xxxxxx xxxx XXXX"); + TestStringValue( + "mask('French àâæçéèêëïîôœùûüÿ ÀÂÆÇÉÈÊËÏÎÔŒÙÛÜŸ')", + "Xxxxxx xxxxxxxxxxxxxxxx XXXXXXXXXXXXXXXX"); + TestStringValue("mask('Greek αβξδ άέήώ ΑΒΞΔ ΆΈΉΏ 1234')", + "Xxxxx xxxx xxxx XXXX XXXX nnnn"); + TestStringValue("mask_first_n('áéíöóőüúű')", "xxxxóőüúű"); + TestStringValue("mask_show_first_n('áéíöóőüúű')", "áéíöxxxxx"); + TestStringValue("mask_last_n('áéíöóőüúű')", "áéíöóxxxx"); + TestStringValue("mask_show_last_n('áéíöóőüúű')", "xxxxxőüúű"); + + // Test masking to unicode code points. Specify -1(unmask) for masking upper/lower/digit + // chars. 
+ TestStringValue("mask('hello李小龙', -1, -1, -1, '某')", "hello某某某"); + TestStringValue("mask_last_n('hello李小龙', 4, -1, -1, -1, '某')", + "hello某某某"); + TestStringValue("mask_last_n('hello李小龙', 2, -1, -1, -1, '某')", + "hello李某某"); + TestStringValue("mask_show_first_n('hello李小龙', 4, -1, -1, -1, '某')", + "hello某某某"); + TestStringValue("mask_show_first_n('hello李小龙', 6, -1, -1, -1, '某')", + "hello李某某"); + TestStringValue("mask_first_n('李小龙hello', 4, -1, -1, -1, '某')", + "某某某hello"); + TestStringValue("mask_show_last_n('李小龙hello', 5, -1, -1, -1, '某')", + "某某某hello"); + executor_->PopExecOption(); +} + TEST_P(ExprTest, Utf8Test) { // Verifies utf8_length() counts length by UTF-8 characters instead of bytes. // '你' and '好' are both encoded into 3 bytes. diff --git a/be/src/exprs/mask-functions-ir.cc b/be/src/exprs/mask-functions-ir.cc index 2bbde4e24..c96be187d 100644 --- a/be/src/exprs/mask-functions-ir.cc +++ b/be/src/exprs/mask-functions-ir.cc @@ -17,6 +17,8 @@ #include "exprs/mask-functions.h" +#include +#include #include #include #include @@ -31,6 +33,7 @@ using namespace impala; using namespace impala_udf; +using namespace boost::locale; const static int CHAR_COUNT = 4; const static int MASKED_UPPERCASE = 'X'; @@ -43,19 +46,43 @@ const static int MASKED_MONTH_COMPONENT_VAL = 0; const static int MASKED_YEAR_COMPONENT_VAL = 1; const static int UNMASKED_VAL = -1; -/// Mask the given char depending on its type. UNMASKED_VAL(-1) means keeping the -/// original value. -static inline uint8_t MaskTransform(uint8_t val, int masked_upper_char, - int masked_lower_char, int masked_digit_char, int masked_other_char) { - if ('A' <= val && val <= 'Z') { +/// Masks the given unicode code point depending on its range and the (optional) given +/// locale. By default, if no locale is provided, i.e. loc == nullptr, +/// lowercase/uppercase/digit characters are only recognized in ascii character set. +/// UNMASKED_VAL(-1) means keeping the original value. 
+/// Returns the masked code point. +static inline uint32_t MaskTransform(uint32_t val, int masked_upper_char, + int masked_lower_char, int masked_digit_char, int masked_other_char, + std::locale* loc = nullptr) { + // Fast code path for masking ascii characters only. + if (loc == nullptr) { + if ('A' <= val && val <= 'Z') { + if (masked_upper_char == UNMASKED_VAL) return val; + return masked_upper_char; + } + if ('a' <= val && val <= 'z') { + if (masked_lower_char == UNMASKED_VAL) return val; + return masked_lower_char; + } + if ('0' <= val && val <= '9') { + if (masked_digit_char == UNMASKED_VAL) return val; + return masked_digit_char; + } + if (masked_other_char == UNMASKED_VAL) return val; + return masked_other_char; + } + // Check facet existence to avoid predicates throws exception. + DCHECK(std::has_facet>(*loc)) + << "Facet not found for locale " << loc->name(); + if (isupper((wchar_t)val, *loc)) { if (masked_upper_char == UNMASKED_VAL) return val; return masked_upper_char; } - if ('a' <= val && val <= 'z') { + if (islower((wchar_t)val, *loc)) { if (masked_lower_char == UNMASKED_VAL) return val; return masked_lower_char; } - if ('0' <= val && val <= '9') { + if (isdigit((wchar_t)val, *loc)) { if (masked_digit_char == UNMASKED_VAL) return val; return masked_digit_char; } @@ -64,7 +91,7 @@ static inline uint8_t MaskTransform(uint8_t val, int masked_upper_char, } /// Mask the substring in range [start, end) of the given string value. Using rules in -/// 'MaskTransform'. +/// 'MaskTransform'. Indices are counted in bytes. static StringVal MaskSubStr(FunctionContext* ctx, const StringVal& val, int start, int end, int masked_upper_char, int masked_lower_char, int masked_digit_char, int masked_other_char) { @@ -82,6 +109,108 @@ static StringVal MaskSubStr(FunctionContext* ctx, const StringVal& val, return result; } +/// Checks whether the unicode code point is malformed, i.e. illegal or incomplete, and +/// warns if it is. Returns true if any warning is added. 
+static bool CheckAndWarnCodePoint(FunctionContext* ctx, uint32_t code_point) { + if (code_point == utf::illegal || code_point == utf::incomplete) { + ctx->AddWarning(Substitute("String contains $0 code point. Return NULL.", + code_point == utf::illegal ? "illegal" : "incomplete").c_str()); + return true; + } + return false; +} + +/// Mask the substring in range [start, end) of the given string value. Using rules in +/// 'MaskTransform'. Indices are counted in UTF-8 code points. +static StringVal MaskSubStrUtf8(FunctionContext* ctx, const StringVal& val, + int start, int end, int masked_upper_char, int masked_lower_char, + int masked_digit_char, int masked_other_char) { + DCHECK_GE(start, 0); + DCHECK_LT(start, end); + DCHECK_LE(end, val.len); + const char* p_start = reinterpret_cast(val.ptr); + const char* p_end = p_start + val.len; + const char* p = p_start; + utf8_codecvt::state_type cvt_state; + int char_cnt = 0; + // Skip leading 'start' code points. Leading bytes will be copied directly. + while (char_cnt < start && p != p_end) { + uint32_t codepoint = utf8_codecvt::to_unicode(cvt_state, p, p_end); + if (CheckAndWarnCodePoint(ctx, codepoint)) return StringVal::null(); + ++char_cnt; + } + // Calculating the result length in bytes. + int result_bytes = p - p_start; + int leading_bytes = result_bytes; + // Collect code points at range [start, end - 1) and mask them. + vector masked_code_points; + // Create unicode locale for checking upper/lower cases or digits. + // TODO(quanlong): Avoid creating this everytime if this is time/resource-consuming. + boost::locale::generator gen; + unique_ptr loc = make_unique(gen("en_US.UTF-8")); + // Check facet existence to avoid predicates throws exception. + if (!std::has_facet>(*loc)) { + ctx->SetError("Cannot mask unicode strings since locale en_US.UTF-8 not found!"); + return StringVal(); + } + while (char_cnt < end && p != p_end) { + // Parse and get the first code point in string range [p, p_end). 
+ // 'to_unicode' will update the pointer 'p'. + uint32_t codepoint = utf8_codecvt::to_unicode(cvt_state, p, p_end); + if (CheckAndWarnCodePoint(ctx, codepoint)) return StringVal::null(); + codepoint = MaskTransform(codepoint, masked_upper_char, masked_lower_char, + masked_digit_char, masked_other_char, loc.get()); + masked_code_points.push_back(codepoint); + result_bytes += utf::utf_traits::width(codepoint); + ++char_cnt; + } + // Trailing bytes will be copied directly without masking. + int tail_len = p_end - p; + result_bytes += tail_len; + + StringVal result(ctx, result_bytes); + if (UNLIKELY(result.is_null)) return result; + // Copy leading bytes. + Ubsan::MemCpy(result.ptr, val.ptr, leading_bytes); + // Converting masked code points to UTF-8 encoded bytes. + char* ptr = reinterpret_cast(result.ptr) + leading_bytes; + p_end = reinterpret_cast(result.ptr) + result_bytes; + for (uint32_t c : masked_code_points) { + uint32_t width = utf8_codecvt::from_unicode(cvt_state, c, ptr, p_end); + DCHECK(width != utf::illegal && width != utf::incomplete); + ptr += width; + DCHECK(ptr <= p_end); + } + // Copy trailing bytes. + if (tail_len > 0) { + DCHECK(ptr < p_end); + Ubsan::MemCpy(ptr, val.ptr + val.len - tail_len, tail_len); + } + result.len = result_bytes; + return result; +} + +/// Counting code points in the UTF-8 encoded string using the same method, 'to_unicode', +/// as MaskSubStrUtf8 uses. So we can have a consistent behavior. +/// Returns -1 if the string contains malformed(illegal/incomplete) code points. 
+static int GetUtf8CodePointCount(FunctionContext* ctx, const StringVal& val) { + utf8_codecvt::state_type cvt_state; + const char* p = reinterpret_cast(val.ptr); + const char* p_end = p + val.len; + int char_cnt = 0; + while (p != p_end) { + uint32_t c = utf8_codecvt::to_unicode(cvt_state, p, p_end); + if (c == utf::illegal || c == utf::incomplete) { + ctx->SetError(Substitute("The $0-th code point $1 is $2", + char_cnt, AnyValUtil::ToString(val), + c == utf::illegal ? "illegal" : "incomplete").c_str()); + return -1; + } + ++char_cnt; + } + return char_cnt; +} + /// Mask the given string except the first 'un_mask_char_count' chars. Ported from /// org.apache.hadoop.hive.ql.udf.generic.GenericUDFMaskShowFirstN. static inline StringVal MaskShowFirstNImpl(FunctionContext* ctx, const StringVal& val, @@ -90,7 +219,11 @@ static inline StringVal MaskShowFirstNImpl(FunctionContext* ctx, const StringVal // To be consistent with Hive, negative char_count is treated as 0. if (un_mask_char_count < 0) un_mask_char_count = 0; if (val.is_null || val.len == 0 || un_mask_char_count >= val.len) return val; - return MaskSubStr(ctx, val, un_mask_char_count, val.len, masked_upper_char, + if (!ctx->impl()->GetConstFnAttr(FunctionContextImpl::UTF8_MODE)) { + return MaskSubStr(ctx, val, un_mask_char_count, val.len, masked_upper_char, + masked_lower_char, masked_digit_char, masked_other_char); + } + return MaskSubStrUtf8(ctx, val, un_mask_char_count, val.len, masked_upper_char, masked_lower_char, masked_digit_char, masked_other_char); } @@ -102,8 +235,14 @@ static inline StringVal MaskShowLastNImpl(FunctionContext* ctx, const StringVal& // To be consistent with Hive, negative char_count is treated as 0. 
if (un_mask_char_count < 0) un_mask_char_count = 0; if (val.is_null || val.len == 0 || un_mask_char_count >= val.len) return val; - return MaskSubStr(ctx, val, 0, val.len - un_mask_char_count, masked_upper_char, - masked_lower_char, masked_digit_char, masked_other_char); + if (!ctx->impl()->GetConstFnAttr(FunctionContextImpl::UTF8_MODE)) { + return MaskSubStr(ctx, val, 0, val.len - un_mask_char_count, masked_upper_char, + masked_lower_char, masked_digit_char, masked_other_char); + } + int end = GetUtf8CodePointCount(ctx, val) - un_mask_char_count; + if (end <= 0) return val; + return MaskSubStrUtf8(ctx, val, 0, end, masked_upper_char, masked_lower_char, + masked_digit_char, masked_other_char); } /// Mask the first 'mask_char_count' chars of the given string. Ported from @@ -113,7 +252,11 @@ static inline StringVal MaskFirstNImpl(FunctionContext* ctx, const StringVal& va int masked_digit_char, int masked_other_char) { if (mask_char_count <= 0 || val.is_null || val.len == 0) return val; if (mask_char_count > val.len) mask_char_count = val.len; - return MaskSubStr(ctx, val, 0, mask_char_count, masked_upper_char, + if (!ctx->impl()->GetConstFnAttr(FunctionContextImpl::UTF8_MODE)) { + return MaskSubStr(ctx, val, 0, mask_char_count, masked_upper_char, masked_lower_char, + masked_digit_char, masked_other_char); + } + return MaskSubStrUtf8(ctx, val, 0, mask_char_count, masked_upper_char, masked_lower_char, masked_digit_char, masked_other_char); } @@ -124,8 +267,14 @@ static inline StringVal MaskLastNImpl(FunctionContext* ctx, const StringVal& val int masked_digit_char, int masked_other_char) { if (mask_char_count <= 0 || val.is_null || val.len == 0) return val; if (mask_char_count > val.len) mask_char_count = val.len; - return MaskSubStr(ctx, val, val.len - mask_char_count, val.len, masked_upper_char, - masked_lower_char, masked_digit_char, masked_other_char); + if (!ctx->impl()->GetConstFnAttr(FunctionContextImpl::UTF8_MODE)) { + return MaskSubStr(ctx, val, val.len - 
mask_char_count, val.len, masked_upper_char, + masked_lower_char, masked_digit_char, masked_other_char); + } + int start = GetUtf8CodePointCount(ctx, val) - mask_char_count; + if (start < 0) start = 0; + return MaskSubStrUtf8(ctx, val, start, val.len, masked_upper_char, masked_lower_char, + masked_digit_char, masked_other_char); } /// Mask the whole given string. Ported from @@ -134,8 +283,12 @@ static inline StringVal MaskImpl(FunctionContext* ctx, const StringVal& val, int masked_upper_char, int masked_lower_char, int masked_digit_char, int masked_other_char) { if (val.is_null || val.len == 0) return val; - return MaskSubStr(ctx, val, 0, val.len, masked_upper_char, - masked_lower_char, masked_digit_char, masked_other_char); + if (!ctx->impl()->GetConstFnAttr(FunctionContextImpl::UTF8_MODE)) { + return MaskSubStr(ctx, val, 0, val.len, masked_upper_char, + masked_lower_char, masked_digit_char, masked_other_char); + } + return MaskSubStrUtf8(ctx, val, 0, val.len, masked_upper_char, masked_lower_char, + masked_digit_char, masked_other_char); } static inline int GetNumDigits(int64_t val) { @@ -254,10 +407,26 @@ static DateVal MaskImpl(FunctionContext* ctx, const DateVal& val, int day_value, return DateValue(year, month, day).ToDateVal(); } -static inline uint8_t GetFirstChar(const StringVal& str, uint8_t default_value) { +/// Gets the first character of 'str'. Returns 'default_value' if 'str' is empty. +/// In UTF-8 mode, the first code point is returned. +/// Otherwise, the first char is returned. +static inline uint32_t GetFirstChar(FunctionContext* ctx, const StringVal& str, + uint32_t default_value) { // To be consistent with Hive, empty string is converted to default value. String with // length > 1 will only use its first char. - return str.len == 0 ? 
default_value : str.ptr[0]; + if (str.len == 0) return default_value; + if (!ctx->impl()->GetConstFnAttr(FunctionContextImpl::UTF8_MODE)) return str.ptr[0]; + + utf8_codecvt::state_type cvt_state; + const char* p = reinterpret_cast(str.ptr); + uint32_t c = utf8_codecvt::to_unicode(cvt_state, p, p + str.len); + if (c == utf::illegal || c == utf::incomplete) { + string msg = Substitute("$0 unicode code point found in the beginning of $1", + c == utf::illegal ? "Illegal" : "Incomplete", AnyValUtil::ToString(str)); + ctx->SetError(msg.c_str()); + return default_value; + } + return c; } /// Get digit (masked_number) from StringVal. Only accept digits or -1. @@ -288,10 +457,16 @@ StringVal MaskFunctions::MaskShowFirstN(FunctionContext* ctx, const StringVal& v const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const StringVal& other_char) { return MaskShowFirstNImpl(ctx, val, char_count.val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), - GetFirstChar(other_char, MASKED_OTHER_CHAR)); + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), + GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); +} +StringVal MaskFunctions::MaskShowFirstN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char) { + return MaskShowFirstNImpl(ctx, val, char_count.val, upper_char.val, lower_char.val, + digit_char.val, GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); } StringVal MaskFunctions::MaskShowFirstN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, @@ -305,9 +480,9 @@ StringVal MaskFunctions::MaskShowFirstN(FunctionContext* ctx, const StringVal& v 
const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char) { return MaskShowFirstNImpl(ctx, val, char_count.val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), other_char.val); } StringVal MaskFunctions::MaskShowFirstN(FunctionContext* ctx, const StringVal& val, @@ -369,10 +544,16 @@ StringVal MaskFunctions::MaskShowLastN(FunctionContext* ctx, const StringVal& va const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const StringVal& other_char) { return MaskShowLastNImpl(ctx, val, char_count.val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), - GetFirstChar(other_char, MASKED_OTHER_CHAR)); + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), + GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); +} +StringVal MaskFunctions::MaskShowLastN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char) { + return MaskShowLastNImpl(ctx, val, char_count.val, upper_char.val, lower_char.val, + digit_char.val, GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); } StringVal MaskFunctions::MaskShowLastN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, @@ -386,9 +567,9 @@ StringVal MaskFunctions::MaskShowLastN(FunctionContext* ctx, const StringVal& va const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char) { return MaskShowLastNImpl(ctx, val, char_count.val, - 
GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), other_char.val); } StringVal MaskFunctions::MaskShowLastN(FunctionContext* ctx, const StringVal& val, @@ -440,10 +621,16 @@ StringVal MaskFunctions::MaskFirstN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const StringVal& other_char) { return MaskFirstNImpl(ctx, val, char_count.val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), - GetFirstChar(other_char, MASKED_OTHER_CHAR)); + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), + GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); +} +StringVal MaskFunctions::MaskFirstN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char) { + return MaskFirstNImpl(ctx, val, char_count.val, upper_char.val, lower_char.val, + digit_char.val, GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); } StringVal MaskFunctions::MaskFirstN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, @@ -457,9 +644,9 @@ StringVal MaskFunctions::MaskFirstN(FunctionContext* ctx, const StringVal& val, const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char) { return MaskFirstNImpl(ctx, val, char_count.val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), + GetFirstChar(ctx, 
upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), other_char.val); } StringVal MaskFunctions::MaskFirstN(FunctionContext* ctx, const StringVal& val, @@ -511,10 +698,16 @@ StringVal MaskFunctions::MaskLastN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const StringVal& other_char) { return MaskLastNImpl(ctx, val, char_count.val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), - GetFirstChar(other_char, MASKED_OTHER_CHAR)); + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), + GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); +} +StringVal MaskFunctions::MaskLastN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char) { + return MaskLastNImpl(ctx, val, char_count.val, upper_char.val, lower_char.val, + digit_char.val, GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); } StringVal MaskFunctions::MaskLastN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, @@ -528,9 +721,9 @@ StringVal MaskFunctions::MaskLastN(FunctionContext* ctx, const StringVal& val, const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char) { return MaskLastNImpl(ctx, val, char_count.val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), other_char.val); } StringVal 
MaskFunctions::MaskLastN(FunctionContext* ctx, const StringVal& val, @@ -577,10 +770,10 @@ StringVal MaskFunctions::Mask(FunctionContext* ctx, const StringVal& val, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const StringVal& other_char) { return MaskImpl(ctx, val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), - GetFirstChar(other_char, MASKED_OTHER_CHAR)); + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), + GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); } StringVal MaskFunctions::Mask(FunctionContext* ctx, const StringVal& val, const StringVal& upper_char, const StringVal& lower_char, @@ -599,9 +792,9 @@ StringVal MaskFunctions::Mask(FunctionContext* ctx, const StringVal& val, const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char) { return MaskImpl(ctx, val, - GetFirstChar(upper_char, MASKED_UPPERCASE), - GetFirstChar(lower_char, MASKED_LOWERCASE), - GetFirstChar(digit_char, MASKED_DIGIT), + GetFirstChar(ctx, upper_char, MASKED_UPPERCASE), + GetFirstChar(ctx, lower_char, MASKED_LOWERCASE), + GetFirstChar(ctx, digit_char, MASKED_DIGIT), other_char.val); } StringVal MaskFunctions::Mask(FunctionContext* ctx, const StringVal& val, @@ -618,6 +811,12 @@ StringVal MaskFunctions::Mask(FunctionContext* ctx, const StringVal& val, const IntVal& year_value) { return Mask(ctx, val, upper_char, lower_char, digit_char, other_char, number_char); } +StringVal MaskFunctions::Mask(FunctionContext* ctx, const StringVal& val, + const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char) { + return MaskImpl(ctx, val, upper_char.val, lower_char.val, digit_char.val, + GetFirstChar(ctx, other_char, MASKED_OTHER_CHAR)); +} StringVal MaskFunctions::Mask(FunctionContext* ctx, const 
StringVal& val, const IntVal& upper_char, const IntVal& lower_char, const IntVal& digit_char, const IntVal& other_char, const IntVal& number_char, const IntVal& day_value, diff --git a/be/src/exprs/mask-functions.h b/be/src/exprs/mask-functions.h index c01aea757..330793398 100644 --- a/be/src/exprs/mask-functions.h +++ b/be/src/exprs/mask-functions.h @@ -103,6 +103,12 @@ class MaskFunctions { const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char); + // Overload for only masking other chars. So we can support patterns like + // mask_show_first_n({col}, 4, -1, -1, -1, 'x') + static StringVal MaskShowFirstN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char); + // Overload that all masked chars are given as integers. static StringVal MaskShowFirstN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, const IntVal& digit_char, const IntVal& other_char, const IntVal& number_char); @@ -146,6 +152,12 @@ class MaskFunctions { const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char); + // Overload for only masking other chars. So we can support patterns like + // mask_show_last_n({col}, 4, -1, -1, -1, 'x') + static StringVal MaskShowLastN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char); + // Overload that all masked chars are given as integers. 
static StringVal MaskShowLastN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, const IntVal& digit_char, const IntVal& other_char, const IntVal& number_char); @@ -184,6 +196,12 @@ class MaskFunctions { const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char); + // Overload for only masking other chars. So we can support patterns like + // mask_first_n({col}, 4, -1, -1, -1, 'x') + static StringVal MaskFirstN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char); + // Overload that all masked chars are given as integers. static StringVal MaskFirstN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, const IntVal& digit_char, const IntVal& other_char, const IntVal& number_char); @@ -222,6 +240,12 @@ class MaskFunctions { const IntVal& char_count, const StringVal& upper_char, const StringVal& lower_char, const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char); + // Overload for only masking other chars. So we can support patterns like + // mask_first_n({col}, 4, -1, -1, -1, 'x') + static StringVal MaskLastN(FunctionContext* ctx, const StringVal& val, + const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char); + // Overload that all masked chars are given as integers. 
static StringVal MaskLastN(FunctionContext* ctx, const StringVal& val, const IntVal& char_count, const IntVal& upper_char, const IntVal& lower_char, const IntVal& digit_char, const IntVal& other_char, const IntVal& number_char); @@ -271,6 +295,12 @@ class MaskFunctions { const StringVal& digit_char, const IntVal& other_char, const StringVal& number_char, const IntVal& day_value, const IntVal& month_value, const IntVal& year_value); + // Overload for only masking other chars. So we can support patterns like + // mask({col}, -1, -1, -1, 'x') + static StringVal Mask(FunctionContext* ctx, const StringVal& val, + const IntVal& upper_char, const IntVal& lower_char, + const IntVal& digit_char, const StringVal& other_char); + // Overload that all masked chars are given as integers. static StringVal Mask(FunctionContext* ctx, const StringVal& val, const IntVal& upper_char, const IntVal& lower_char, const IntVal& digit_char, const IntVal& other_char, const IntVal& number_char, const IntVal& day_value, diff --git a/common/function-registry/impala_functions.py b/common/function-registry/impala_functions.py index 345706dda..dc013494b 100644 --- a/common/function-registry/impala_functions.py +++ b/common/function-registry/impala_functions.py @@ -828,6 +828,8 @@ visible_functions = [ [['mask_show_first_n'], 'STRING', ['STRING', 'INT'], 'impala::MaskFunctions::MaskShowFirstN'], [['mask_show_first_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING'], 'impala::MaskFunctions::MaskShowFirstN'], + [['mask_show_first_n'], 'STRING', ['STRING', 'INT', 'INT', 'INT', 'INT', 'STRING'], + 'impala::MaskFunctions::MaskShowFirstN'], [['mask_show_first_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING', 'INT'], 'impala::MaskFunctions::MaskShowFirstN'], [['mask_show_first_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'INT', 'STRING'], @@ -856,6 +858,8 @@ visible_functions = [ [['mask_show_last_n'], 'STRING', ['STRING', 'INT'], 
'impala::MaskFunctions::MaskShowLastN'], [['mask_show_last_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING'], 'impala::MaskFunctions::MaskShowLastN'], + [['mask_show_last_n'], 'STRING', ['STRING', 'INT', 'INT', 'INT', 'INT', 'STRING'], + 'impala::MaskFunctions::MaskShowLastN'], [['mask_show_last_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING', 'INT'], 'impala::MaskFunctions::MaskShowLastN'], [['mask_show_last_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'INT', 'STRING'], @@ -886,6 +890,8 @@ visible_functions = [ [['mask_first_n'], 'STRING', ['STRING', 'INT'], 'impala::MaskFunctions::MaskFirstN'], [['mask_first_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING'], 'impala::MaskFunctions::MaskFirstN'], + [['mask_first_n'], 'STRING', ['STRING', 'INT', 'INT', 'INT', 'INT', 'STRING'], + 'impala::MaskFunctions::MaskFirstN'], [['mask_first_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING', 'INT'], 'impala::MaskFunctions::MaskFirstN'], [['mask_first_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'INT', 'STRING'], @@ -916,6 +922,8 @@ visible_functions = [ [['mask_last_n'], 'STRING', ['STRING', 'INT'], 'impala::MaskFunctions::MaskLastN'], [['mask_last_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING'], 'impala::MaskFunctions::MaskLastN'], + [['mask_last_n'], 'STRING', ['STRING', 'INT', 'INT', 'INT', 'INT', 'STRING'], + 'impala::MaskFunctions::MaskLastN'], [['mask_last_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'STRING', 'INT'], 'impala::MaskFunctions::MaskLastN'], [['mask_last_n'], 'STRING', ['STRING', 'INT', 'STRING', 'STRING', 'STRING', 'INT', 'STRING'], @@ -945,6 +953,8 @@ visible_functions = [ [['mask'], 'STRING', ['STRING'], 'impala::MaskFunctions::Mask'], [['mask'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 'STRING'], 'impala::MaskFunctions::Mask'], + [['mask'], 'STRING', ['STRING', 'INT', 'INT', 
'INT', 'STRING'], + 'impala::MaskFunctions::Mask'], [['mask'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 'STRING', 'INT'], 'impala::MaskFunctions::Mask'], [['mask'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 'INT', 'STRING'], diff --git a/testdata/workloads/functional-query/queries/QueryTest/utf8-string-functions.test b/testdata/workloads/functional-query/queries/QueryTest/utf8-string-functions.test index 1f7e4b845..84bab4b28 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/utf8-string-functions.test +++ b/testdata/workloads/functional-query/queries/QueryTest/utf8-string-functions.test @@ -168,3 +168,15 @@ select locate('SQL', '最快的SQL引擎跑SQL'), ---- TYPES INT,INT,INT,INT,INT,INT ==== +---- QUERY +set utf8_mode=true; +select mask('SQL引擎', 'x', 'x', 'x', 'x'), + mask_last_n('SQL引擎', 2, 'x', 'x', 'x', 'x'), + mask_show_first_n('SQL引擎', 2, 'x', 'x', 'x', 'x'), + mask_first_n('SQL引擎', 2, 'x', 'x', 'x', 'x'), + mask_show_last_n('SQL引擎', 2, 'x', 'x', 'x', 'x'); +---- RESULTS: RAW_STRING +'xxxxx','SQLxx','SQxxx','xxL引擎','xxx引擎' +---- TYPES +STRING,STRING,STRING,STRING,STRING +====